from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
Bhargav Sai Gogineni
email : bgoginen@iu.edu
Harini Mohanasundaram
email : harmohan@iu.edu
Tarika Sadey
email : tsadey@iu.edu
Suraj Gupta Gudla
email : surgudla@iu.edu
One of the biggest challenges in the field of computer vision is Image Classification and detection which has varied applications starting from the field of medicine, space object detection and so on. The aim of this project is to perform image classification analysis on a dataset to differentiate between cat and dog images using classification and regression models. We first preprocess the images from the image catalogue using different function metrics and using the RGB intensity value arrays we feed them into classifiers to predict labels(cat/dog) and bounding boxes for which we will implement baseline model Logistic Regression as our classifier and optimize using stochastic gradient descent at an adaptive learning rate, homegrown logistic regression and linear regression to predict class, bounding boxes and loss function attributes using both sklearn and pytorch. We then plan to extend our model training by implementing Convolutional Neural Network(CNN) model for single object detection using Pytorch and use different performance metrics such as RMSE, MSE and accuracy to measure our model performance.
Our aim for this project is to build object detection pipelines using Python, OpenCV, SKLearn, and PyTorch to detect cat and dog images.We import the image catalogue data, perform Exploratory data analysis on it, derrive some metrics and baseline models on the data.In order to create a detector, we will first have to preprocess the images to be all of the same shapes, take their RGB intensity values and flatten them from a 3D array to 2D. Then we will feed this array into a linear classifier and a linear regressor to predict labels and bounding boxes.
Build an SKLearn model for image classification and another model for regression Implement a Homegrown Logistic Regression model. Extend the loss function from CXE to CXE + MSE Build a baseline pipeline in PyTorch to object classification and object localization
Build a convolutional neural network network for single object classifier and detector.
The data set consists of about 12,966 RGB images of cats and dogs with varying shapes and aspect ratios. The image bounding box coordinates are stored in a .csv file which contain image description, box coordinate descriptions along with some required attributes. We define some of the data attributes as below:
from collections import Counter
import glob
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from PIL import Image
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import SGDClassifier, SGDRegressor
from sklearn.metrics import accuracy_score, mean_squared_error, roc_auc_score
from sklearn.model_selection import train_test_split
import tarfile
from tqdm.notebook import tqdm
import warnings
def extract_tar(file, path):
"""
function to extract tar.gz files to specified location
Args:
file (str): path where the file is located
path (str): path where you want to extract
"""
with tarfile.open(file) as tar:
files_extracted = 0
for member in tqdm(tar.getmembers()):
if os.path.isfile(path + member.name[1:]):
continue
else:
tar.extract(member, path)
files_extracted += 1
tar.close()
if files_extracted < 3:
print('Files already exist')
path = 'images/'
extract_tar('/content/drive/MyDrive/AML_Project/cadod.tar.gz', path)
df = pd.read_csv('/content/drive/MyDrive/AML_Project/cadod.csv')
df.head()
| ImageID | Source | LabelName | Confidence | XMin | XMax | YMin | YMax | IsOccluded | IsTruncated | IsGroupOf | IsDepiction | IsInside | XClick1X | XClick2X | XClick3X | XClick4X | XClick1Y | XClick2Y | XClick3Y | XClick4Y | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0000b9fcba019d36 | xclick | /m/0bt9lr | 1 | 0.165000 | 0.903750 | 0.268333 | 0.998333 | 1 | 1 | 0 | 0 | 0 | 0.636250 | 0.903750 | 0.748750 | 0.165000 | 0.268333 | 0.506667 | 0.998333 | 0.661667 |
| 1 | 0000cb13febe0138 | xclick | /m/0bt9lr | 1 | 0.000000 | 0.651875 | 0.000000 | 0.999062 | 1 | 1 | 0 | 0 | 0 | 0.312500 | 0.000000 | 0.317500 | 0.651875 | 0.000000 | 0.410882 | 0.999062 | 0.999062 |
| 2 | 0005a9520eb22c19 | xclick | /m/0bt9lr | 1 | 0.094167 | 0.611667 | 0.055626 | 0.998736 | 1 | 1 | 0 | 0 | 0 | 0.487500 | 0.611667 | 0.243333 | 0.094167 | 0.055626 | 0.226296 | 0.998736 | 0.305942 |
| 3 | 0006303f02219b07 | xclick | /m/0bt9lr | 1 | 0.000000 | 0.999219 | 0.000000 | 0.998824 | 1 | 1 | 0 | 0 | 0 | 0.508594 | 0.999219 | 0.000000 | 0.478906 | 0.000000 | 0.375294 | 0.720000 | 0.998824 |
| 4 | 00064d23bf997652 | xclick | /m/0bt9lr | 1 | 0.240938 | 0.906183 | 0.000000 | 0.694286 | 0 | 0 | 0 | 0 | 0 | 0.678038 | 0.906183 | 0.240938 | 0.522388 | 0.000000 | 0.370000 | 0.424286 | 0.694286 |
df.columns
Index(['ImageID', 'Source', 'LabelName', 'Confidence', 'XMin', 'XMax', 'YMin',
'YMax', 'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsDepiction',
'IsInside', 'XClick1X', 'XClick2X', 'XClick3X', 'XClick4X', 'XClick1Y',
'XClick2Y', 'XClick3Y', 'XClick4Y'],
dtype='object')
print(f"There are a total of {len(glob.glob1(path, '*.jpg'))} images")
There are a total of 12966 images
print(f"The total size is {os.path.getsize(path)/1000} MB")
The total size is 1101.824 MB
df.shape
(12966, 21)
Replace LabelName with human readable labels
df.LabelName.replace({'/m/01yrx':'cat', '/m/0bt9lr':'dog'}, inplace=True)
df.LabelName.value_counts()
dog 6855 cat 6111 Name: LabelName, dtype: int64
df.LabelName.value_counts().plot(kind='bar')
plt.title('Image Class Count')
plt.show()
df.describe()
| Confidence | XMin | XMax | YMin | YMax | IsOccluded | IsTruncated | IsGroupOf | IsDepiction | IsInside | XClick1X | XClick2X | XClick3X | XClick4X | XClick1Y | XClick2Y | XClick3Y | XClick4Y | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 12966.0 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 | 12966.000000 |
| mean | 1.0 | 0.099437 | 0.901750 | 0.088877 | 0.945022 | 0.464754 | 0.738470 | 0.013651 | 0.045427 | 0.001157 | 0.390356 | 0.424582 | 0.494143 | 0.506689 | 0.275434 | 0.447448 | 0.641749 | 0.582910 |
| std | 0.0 | 0.113023 | 0.111468 | 0.097345 | 0.081500 | 0.499239 | 0.440011 | 0.118019 | 0.209354 | 0.040229 | 0.358313 | 0.441751 | 0.405033 | 0.462281 | 0.415511 | 0.401580 | 0.448054 | 0.403454 |
| min | 1.0 | 0.000000 | 0.408125 | 0.000000 | 0.451389 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 |
| 25% | 1.0 | 0.000000 | 0.830625 | 0.000000 | 0.910000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.221293 | 0.096875 | 0.285071 | 0.130000 | 0.024323 | 0.218333 | 0.405817 | 0.400000 |
| 50% | 1.0 | 0.061250 | 0.941682 | 0.059695 | 0.996875 | 0.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.435625 | 0.415625 | 0.531919 | 0.623437 | 0.146319 | 0.480839 | 0.825000 | 0.646667 |
| 75% | 1.0 | 0.167500 | 0.998889 | 0.144853 | 0.999062 | 1.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.609995 | 0.820000 | 0.787500 | 0.917529 | 0.561323 | 0.729069 | 0.998042 | 0.882500 |
| max | 1.0 | 0.592500 | 1.000000 | 0.587088 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 0.999375 | 0.999375 | 1.000000 | 0.999375 | 0.999375 | 0.999375 | 1.000000 | 0.999375 |
df.IsOccluded.value_counts().plot(kind='bar')
plt.title('Image Occlusion')
plt.xlabel("Not Occluded = 0, Occluded = 1, Unsure = -1")
plt.show()
count_unique = Counter(df['IsOccluded'])
print(count_unique)
Counter({0: 6934, 1: 6029, -1: 3})
df.IsTruncated.value_counts().plot(kind='bar')
plt.title('Image Truncation')
plt.xlabel("Not Truncated = 0, Truncated = 1, Unsure = -1")
plt.show()
count_unique = Counter(df['IsTruncated'])
print(count_unique)
Counter({1: 9578, 0: 3385, -1: 3})
df.IsGroupOf.value_counts().plot(kind='bar')
plt.title('Grouped Images')
plt.xlabel("Is Not Part of Group = 0, Is Part of Group = 1, Unsure = -1")
plt.show()
count_unique = Counter(df['IsGroupOf'])
print(count_unique)
Counter({0: 12783, 1: 180, -1: 3})
df.IsDepiction.value_counts().plot(kind='bar')
plt.title('Image Depiction')
plt.xlabel("Not Depicted = 0, Depicted = 1, Unsure = -1")
plt.show()
count_unique = Counter(df['IsDepiction'])
print(count_unique)
Counter({0: 12371, 1: 592, -1: 3})
df.IsInside.value_counts().plot(kind='bar')
plt.title('Image Inside/Outside')
plt.xlabel("Not Inside = 0, Inside = 1, Unsure = -1")
plt.show()
count_unique = Counter(df['IsInside'])
print(count_unique)
Counter({0: 12945, 1: 18, -1: 3})
# plot random 6 images
fig, ax = plt.subplots(nrows=2, ncols=3, sharex=False, sharey=False,figsize=(15,10))
ax = ax.flatten()
for i,j in enumerate(np.random.choice(df.shape[0], size=6, replace=False)):
img = mpimg.imread(path + df.ImageID.values[j] + '.jpg')
h, w = img.shape[:2]
coords = df.iloc[j,4:8]
ax[i].imshow(img)
ax[i].set_title(df.LabelName[j])
ax[i].add_patch(plt.Rectangle((coords[0]*w, coords[2]*h),
coords[1]*w-coords[0]*w, coords[3]*h-coords[2]*h,
edgecolor='red', facecolor='none'))
plt.tight_layout()
plt.show()
Go through all images and record the shape of the image in pixels and the memory size
img_shape = []
img_size = np.zeros((df.shape[0], 1))
for i,f in enumerate(tqdm(glob.glob1(path, '*.jpg'))):
file = path+'/'+f
img = Image.open(file)
img_shape.append(f"{img.size[0]}x{img.size[1]}")
img_size[i] += os.path.getsize(file)
Count all the different image shapes
img_shape_count = Counter(img_shape)
# create a dataframe for image shapes
img_df = pd.DataFrame(set(img_shape_count.items()), columns=['img_shape','img_count'])
img_df.shape
(594, 2)
img_df.head(5)
| img_shape | img_count | |
|---|---|---|
| 0 | 512x321 | 5 |
| 1 | 330x512 | 3 |
| 2 | 512x381 | 14 |
| 3 | 512x269 | 2 |
| 4 | 512x397 | 8 |
There are a ton of different image shapes. Let's narrow this down by getting a sum of any image shape that has a cout less than 100 and put that in a category called other
img_df = img_df.append({'img_shape': 'other','img_count': img_df[img_df.img_count < 100].img_count.sum()},
ignore_index=True)
Drop all image shapes
img_df = img_df[img_df.img_count >= 100]
Check if the count sum matches the number of images
img_df.img_count.sum() == df.shape[0]
True
Plot
img_df.sort_values('img_count', inplace=True)
img_df.plot(x='img_shape', y='img_count', kind='barh', figsize=(8,8), legend=False)
plt.title('Image Shape Counts')
plt.show()
# convert to megabytes
img_size = img_size / 1000
fig, ax = plt.subplots(1, 2, figsize=(15,5))
fig.suptitle('Image Size Distribution')
ax[0].hist(img_size, bins=50)
ax[0].set_title('Histogram')
ax[0].set_xlabel('Image Size (MB)')
ax[1].boxplot(img_size, vert=False, widths=0.5)
ax[1].set_title('Boxplot')
ax[1].set_xlabel('Image Size (MB)')
ax[1].set_ylabel('Images')
plt.show()
!mkdir -p images/resized
%%time
# resize image and save, convert to numpy
img_arr = np.zeros((df.shape[0],32*32*3)) # initialize np.array
for i, f in enumerate(tqdm(df.ImageID)):
img = Image.open(path+f+'.jpg')
img_resized = img.resize((32,32))
img_resized.save("images/resized/"+f+'.jpg', "JPEG", optimize=True)
img_arr[i] = np.asarray(img_resized, dtype=np.uint8).flatten()
CPU times: user 1min 4s, sys: 2.43 s, total: 1min 7s Wall time: 1min 6s
df["ImageArray"]=[img for img in img_arr]
df.head()
| ImageID | Source | LabelName | Confidence | XMin | XMax | YMin | YMax | IsOccluded | IsTruncated | IsGroupOf | IsDepiction | IsInside | XClick1X | XClick2X | XClick3X | XClick4X | XClick1Y | XClick2Y | XClick3Y | XClick4Y | ImageArray | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0000b9fcba019d36 | xclick | dog | 1 | 0.165000 | 0.903750 | 0.268333 | 0.998333 | 1 | 1 | 0 | 0 | 0 | 0.636250 | 0.903750 | 0.748750 | 0.165000 | 0.268333 | 0.506667 | 0.998333 | 0.661667 | [59.0, 128.0, 182.0, 68.0, 131.0, 184.0, 68.0,... |
| 1 | 0000cb13febe0138 | xclick | dog | 1 | 0.000000 | 0.651875 | 0.000000 | 0.999062 | 1 | 1 | 0 | 0 | 0 | 0.312500 | 0.000000 | 0.317500 | 0.651875 | 0.000000 | 0.410882 | 0.999062 | 0.999062 | [89.0, 89.0, 87.0, 93.0, 93.0, 91.0, 140.0, 13... |
| 2 | 0005a9520eb22c19 | xclick | dog | 1 | 0.094167 | 0.611667 | 0.055626 | 0.998736 | 1 | 1 | 0 | 0 | 0 | 0.487500 | 0.611667 | 0.243333 | 0.094167 | 0.055626 | 0.226296 | 0.998736 | 0.305942 | [94.0, 100.0, 90.0, 90.0, 96.0, 86.0, 80.0, 85... |
| 3 | 0006303f02219b07 | xclick | dog | 1 | 0.000000 | 0.999219 | 0.000000 | 0.998824 | 1 | 1 | 0 | 0 | 0 | 0.508594 | 0.999219 | 0.000000 | 0.478906 | 0.000000 | 0.375294 | 0.720000 | 0.998824 | [132.0, 101.0, 80.0, 149.0, 118.0, 95.0, 161.0... |
| 4 | 00064d23bf997652 | xclick | dog | 1 | 0.240938 | 0.906183 | 0.000000 | 0.694286 | 0 | 0 | 0 | 0 | 0 | 0.678038 | 0.906183 | 0.240938 | 0.522388 | 0.000000 | 0.370000 | 0.424286 | 0.694286 | [1.0, 44.0, 8.0, 0.0, 43.0, 8.0, 1.0, 46.0, 11... |
img_arr[0]
array([ 59., 128., 182., ..., 59., 141., 203.])
Plot the resized and filtered images
# plot random 6 images
fig, ax = plt.subplots(nrows=2, ncols=3, sharex=False, sharey=False,figsize=(15,10))
ax = ax.flatten()
for i,j in enumerate(np.random.choice(df.shape[0], size=6, replace=False)):
img = mpimg.imread(path+'/resized/'+df.ImageID.values[j]+'.jpg')
h, w = img.shape[:2]
coords = df.iloc[j,4:8]
ax[i].imshow(img)
ax[i].set_title(df.iloc[j,2])
ax[i].add_patch(plt.Rectangle((coords[0]*w, coords[2]*h),
coords[1]*w-coords[0]*w, coords[3]*h-coords[2]*h,
edgecolor='red', facecolor='none'))
plt.tight_layout()
plt.show()
# encode labels
df['Label'] = (df.LabelName == 'dog').astype(np.uint8)
# plot first 6 images
fig, ax = plt.subplots(nrows=2, ncols=3, sharex=False, sharey=False,figsize=(15,10))
ax = ax.flatten()
for i,j in enumerate(df.index[5584:5590].to_numpy()):
img = mpimg.imread(path + df.ImageID.values[j] + '.jpg')
h, w = img.shape[:2]
coords = df.iloc[j,4:8]
ax[i].imshow(img)
ax[i].set_title(df.LabelName[j])
ax[i].add_patch(plt.Rectangle((coords[0]*w, coords[2]*h),
coords[1]*w-coords[0]*w, coords[3]*h-coords[2]*h,
edgecolor='red', facecolor='none'))
plt.tight_layout()
plt.show()
# plot the same 6 images shown above after transforamtion
fig, ax = plt.subplots(nrows=2, ncols=3, sharex=False, sharey=False,figsize=(15,10))
ax = ax.flatten()
for i,j in enumerate(df.index[5584:5590].to_numpy()):
img = mpimg.imread(path+'/resized/'+df.ImageID.values[j]+'.jpg')
h, w = img.shape[:2]
coords = df.iloc[j,4:8]
ax[i].imshow(img)
ax[i].set_title(df.iloc[j,2])
ax[i].add_patch(plt.Rectangle((coords[0]*w, coords[2]*h),
coords[1]*w-coords[0]*w, coords[3]*h-coords[2]*h,
edgecolor='red', facecolor='none'))
plt.tight_layout()
plt.show()
mkdir -p data
np.save('data/img.npy', img_arr.astype(np.uint8))
np.save('data/y_label.npy', df.Label.values)
np.save('data/y_bbox.npy', df[['XMin', 'YMin', 'XMax', 'YMax']].values.astype(np.float32))
X = np.load('data/img.npy', allow_pickle=True)
y_label = np.load('data/y_label.npy', allow_pickle=True)
y_bbox = np.load('data/y_bbox.npy', allow_pickle=True)
df.columns
Index(['ImageID', 'Source', 'LabelName', 'Confidence', 'XMin', 'XMax', 'YMin',
'YMax', 'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsDepiction',
'IsInside', 'XClick1X', 'XClick2X', 'XClick3X', 'XClick4X', 'XClick1Y',
'XClick2Y', 'XClick3Y', 'XClick4Y', 'ImageArray', 'Label'],
dtype='object')
idx_to_label = {1:'dog', 0:'cat'} # encoder
Double check that it loaded correctly
# plot random 6 images
fig, ax = plt.subplots(nrows=2, ncols=3, sharex=False, sharey=False,figsize=(15,10))
ax = ax.flatten()
for i,j in enumerate(np.random.choice(X.shape[0], size=6, replace=False)):
coords = y_bbox[j] * 32
ax[i].imshow(X[j].reshape(32,32,3))
ax[i].set_title(idx_to_label[y_label[j]])
ax[i].add_patch(plt.Rectangle((coords[0], coords[1]),
coords[2]-coords[0], coords[3]-coords[1],
edgecolor='red', facecolor='none'))
plt.tight_layout()
plt.show()
Create training and testing sets
X_train, X_test, y_train, y_test_label = train_test_split(X, y_label, test_size=0.01, random_state=27)
I'm choosing SGDClassifier because the data is large and I want to be able to perform stochastic gradient descent and also its ability to early stop. With this many parameters, a model can easily overfit so it's important to try and find the point of where it begins to overfit and stop for optimal results.
%%time
model = SGDClassifier(loss='log', n_jobs=-1, random_state=27, learning_rate='adaptive', eta0=1e-10,
early_stopping=True, validation_fraction=0.1, n_iter_no_change=3)
# 0.2 validation TODO
model.fit(X_train, y_train)
CPU times: user 1.06 s, sys: 511 ms, total: 1.57 s Wall time: 1.15 s
model.n_iter_
4
Did it stop too early? Let's retrain with a few more iterations to see. Note that SGDClassifier has a parameter called validation_fraction which splits a validation set from the training data to determine when it stops.
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.1, random_state=27)
model2 = SGDClassifier(loss='log', n_jobs=-1, random_state=27, learning_rate='adaptive', eta0=1e-10)
epochs = 30
train_acc = np.zeros(epochs)
valid_acc = np.zeros(epochs)
for i in tqdm(range(epochs)):
model2.partial_fit(X_train, y_train, np.unique(y_train))
#log
train_acc[i] += np.round(accuracy_score(y_train, model2.predict(X_train)),3)
valid_acc[i] += np.round(accuracy_score(y_valid, model2.predict(X_valid)),3)
plt.plot(train_acc, label='train')
plt.plot(valid_acc, label='valid')
plt.title('CaDoD Training')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
del model2
expLog = pd.DataFrame(columns=["exp_name",
"Train Acc",
"Valid Acc",
"Test Acc",
"Train MSE",
"Valid MSE",
"Test MSE",
])
exp_name = f"Baseline: Linear Model"
expLog.loc[0,:4] = [f"{exp_name}"] + list(np.round(
[accuracy_score(y_train, model.predict(X_train)),
accuracy_score(y_valid, model.predict(X_valid)),
accuracy_score(y_test_label, model.predict(X_test))],3))
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: FutureWarning: Slicing a positional slice with .loc is not supported, and will raise TypeError in a future version. Use .loc with labels or .iloc with positions instead. """
expLog
| exp_name | Train Acc | Valid Acc | Test Acc | Train MSE | Valid MSE | Test MSE | |
|---|---|---|---|---|---|---|---|
| 0 | Baseline: Linear Model | 0.565 | 0.557 | 0.615 | NaN | NaN | NaN |
y_pred_label = model.predict(X_test)
y_pred_label_proba = model.predict_proba(X_test)
fig, ax = plt.subplots(nrows=2, ncols=5, sharex=False, sharey=False,figsize=(15,6))
ax = ax.flatten()
for i in range(10):
img = X_test[i].reshape(32,32,3)
ax[i].imshow(img)
ax[i].set_title("Ground Truth: {0} \n Prediction: {1} | {2:.2f}".format(idx_to_label[y_test_label[i]],
idx_to_label[y_pred_label[i]],
y_pred_label_proba[i][y_pred_label[i]]),
color=("green" if y_pred_label[i]==y_test_label[i] else "red"))
plt.tight_layout()
plt.show()
Import Required Libraries
# imports
import warnings
warnings.simplefilter('ignore')
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
import re
from time import time
from scipy import stats
import json
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import ShuffleSplit
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.linear_model import SGDClassifier
from sklearn.ensemble import RandomForestClassifier
# Create a class to select numerical or categorical columns
# since Scikit-Learn doesn't handle DataFrames yet
class DataFrameSelector(BaseEstimator, TransformerMixin):
def __init__(self, attribute_names):
self.attribute_names = attribute_names
def fit(self, X, y=None):
return self
def transform(self, X):
return X[self.attribute_names].values
# Imports for metrics
from sklearn.model_selection import cross_val_score, ShuffleSplit
# Imports for stats
from scipy import stats
# Convert a number to a percent.
def pct(x):
return round(100*x,1)
# Set up reporting
import pandas as pd
results = pd.DataFrame(columns=["ExpID", "Cross fold train accuracy", "Test Accuracy", "p-value", "Train Time(s)", "Test Time(s)", "Experiment description"])
# Set up ShuffleSplit for p_value testing
cv = ShuffleSplit(n_splits=30, test_size=0.3, random_state=0)
def ttest(control, treatment):
#paired t-test; two-tailed p-value A , B
(t_score, p_value) = stats.ttest_rel(control, treatment)
if p_value > 0.05/2: #Two sided
print('There is no significant difference between the two machine learning pipelines (Accept H0)')
else:
print('The two machine learning pipelines are different (reject H0) \n(t_score, p_value) = (%.2f, %.5f)'%(t_score, p_value) )
if t_score > 0.0: #in the case of regression lower RMSE is better; A is lower
print('Machine learning pipeline A is better than B')
else:
print('Machine learning pipeline B is better than A')
return p_value
data = df
y_data = data['Label']
x_data = data.drop(['Label', 'XClick1X', 'XClick2X','XClick3X','XClick4X','XClick1Y', 'XClick2Y','XClick3Y','XClick4Y','Source'], axis = 1)
corr_matrix = data.corr()
corr_matrix["Label"].sort_values(ascending=False)
Label 1.000000 IsOccluded 0.397112 XClick4X 0.167493 XClick2Y 0.134961 XClick1Y 0.132553 XClick4Y 0.109203 XMin 0.103046 YMin 0.083317 XClick3X 0.077203 XClick1X 0.072189 XClick2X 0.041927 XClick3Y 0.035337 YMax 0.029408 IsTruncated 0.020644 IsDepiction 0.018157 IsGroupOf 0.003171 IsInside -0.030458 XMax -0.085277 Confidence NaN Name: Label, dtype: float64
# Correlation observation
from pandas.plotting import scatter_matrix
# Top four correlated inputs with survived
attributes = ["Label", "IsOccluded", "IsTruncated", "IsInside", "IsGroupOf"]
scatter_matrix(data[attributes], figsize=(12, 8));
cat_vars = ['IsOccluded', 'IsTruncated', 'IsInside','IsGroupOf']
plt.figure(figsize=(15,4))
for idx, cat in enumerate(cat_vars):
plt.subplot(1, 4, idx+1)
sns.countplot(data[cat], hue=data['LabelName'])
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.2, shuffle=True, random_state=42)
from sklearn.pipeline import Pipeline, FeatureUnion
# Identify the numeric features we wish to consider.
num_attribs = [
'IsOccluded',
'IsTruncated',
'IsInside',
'IsGroupOf'
]
# Create a pipeline for the numeric features.
# Use DataFrameSelector with the numeric features defined above
# Use StandardScaler() to standardize the data
# Missing values will be imputed using the feature median.
num_pipeline = Pipeline([('selector', DataFrameSelector(num_attribs)),
('imputer', SimpleImputer(strategy="median")),
('standard_scaler', StandardScaler()),])
# Identify the categorical features we wish to consider.
cat_attribs = [
"LabelName"
]
# Identiy the range of expected values for the categorical features.
cat_values = [
['Dog','Cat'], # Embarked
]
# Create a pipelne for the categorical features.
# Entries with missing values or values that don't exist in the range
# defined above will be one hot encoded as zeroes.
cat_pipeline = Pipeline([
('selector', DataFrameSelector(cat_attribs)),
('imputer', SimpleImputer(strategy='most_frequent')),
('ohe', OneHotEncoder(sparse=False, handle_unknown="ignore"))
])
full_pipeline = FeatureUnion(transformer_list=[("num_pipeline", num_pipeline),("cat_pipeline", cat_pipeline),]) #TODO <- ColumnTransformer
X = np.load('data/img.npy', allow_pickle=True)
y_label = np.load('data/y_label.npy', allow_pickle=True)
y_bbox = np.load('data/y_bbox.npy', allow_pickle=True)
# y is y_label
X_train_class, X_test_class, y_train_label, y_test_label = train_test_split(X, y_label, stratify=y_label, shuffle=True, test_size=0.50, random_state=27)
idx_to_label = {1:'dog', 0:'cat'} # encoder
# scale data
np.random.seed(42)
X = X.astype(np.float32) / 255.
y_label=y_label.astype(int)
X_train_class, X_test_class, y_train_label, y_test_label = train_test_split(X, y_label, stratify=y_label, shuffle=True, test_size=0.20, random_state=27)
X_train_full = X_train_class
y_train_full = y_train_label
X_test_full = X_test_class
y_test_full = y_test_label
X_train, _, y_train, _ = train_test_split(X_train_class, y_train_label, stratify=y_train_label, train_size=0.1, random_state=42)
X_test, _, y_test, _ = train_test_split(X_test_class, y_test_label, stratify=y_test_label, train_size=0.1, random_state=42)
# use full pipeline above to build full pipeline with predictor
np.random.seed(42)
full_pipeline_with_predictor = Pipeline([
("linear", LogisticRegression(random_state=42))
])
start = time()
x_train = X_train
y_train = y_train
full_pipeline_with_predictor.fit(x_train, y_train)
np.random.seed(42)
start = time()
x_train = X_train
y_train = y_train
full_pipeline_with_predictor.fit(x_train, y_train)
np.random.seed(42)
# Set up cross validation scores
# Use ShuffleSplit() with 30 splits, 30% test_size
# and a random seed of 0
#==================================================#
# Your code starts here #
#==================================================#
cv3Splits = ShuffleSplit(n_splits=3, random_state=0, test_size=0.30, train_size=None)
logit_scores = cross_val_score(full_pipeline_with_predictor, x_train, y_train, cv=cv3Splits)
#==================================================#
# Your code ends here #
# Please don't add code below here #
#==================================================#
logit_score_train = logit_scores.mean()
train_time = np.round(time() - start, 4)
# Time and score test predictions
x_test = X_test
y_test = y_test
start = time()
logit_score_test = full_pipeline_with_predictor.score(x_test, y_test)
test_time = np.round(time() - start, 4)
results.loc[0] = ["Baseline", pct(logit_score_train), np.round(pct(logit_score_test),3),
"---", train_time, test_time, "Untuned LogisticRegression"]
results
| ExpID | Cross fold train accuracy | Test Accuracy | p-value | Train Time(s) | Test Time(s) | Experiment description | |
|---|---|---|---|---|---|---|---|
| 0 | Baseline | 54.0 | 55.2 | --- | 1.4712 | 0.0085 | Untuned LogisticRegression |
# A Function to execute the grid search and record the results.
def ConductGridSearch(X_train, y_train, X_test, y_test, i=0, prefix='', n_jobs=-1,verbose=1):
# Create a list of classifiers for our grid search experiment
classifiers = [
('Logistic Regression', LogisticRegression(random_state=42)),
('K-Nearest Neighbors', KNeighborsClassifier()),
('Naive Bayes', GaussianNB()),
('Support Vector', SVC(random_state=42)),
('Stochastic GD', SGDClassifier(loss='log',
penalty='l2',
early_stopping=True,
max_iter=10000, tol=1e-5,
random_state=42)),
('RandomForest', RandomForestClassifier(random_state=42)),
('xgb', xgb.XGBRegressor(validation_fraction=0.2,
n_iter_no_change=5, tol=0.01,
random_state=0, verbose=1))
]
# Arrange grid search parameters for each classifier
params_grid = {
'Logistic Regression': {
'penalty': ('l1', 'l2'),
'tol': (0.0001, 0.00001, 0.0000001),
'C': (10, 1, 0.1, 0.01),
},
'K-Nearest Neighbors': {
'n_neighbors': (3, 5, 7, 8, 11),
'p': (1,2),
},
'Naive Bayes': {},
'Support Vector' : {
'kernel': ('rbf', 'poly'),
'degree': (1, 2, 3, 4, 5),
'C': (10, 1, 0.1, 0.01),
},
'Stochastic GD': {
'tol': (0.0001, 0.0000001),
'alpha': (0.1, 0.001, 0.0001),
},
'RandomForest': {
'max_depth': [9, 15, 22],
'max_features': [3, 5, 50],
'min_samples_split': [2, 5, 15],
'min_samples_leaf': [2, 3, 5],
'bootstrap': [False],
'n_estimators':[20, 80, 300]},
'xgb': {
'n_estimators':[20,80,300]
},
}
for (name, classifier) in classifiers:
i += 1
# Print classifier and parameters
print('****** START',prefix, name,'*****')
parameters = params_grid[name]
print("Parameters:")
for p in sorted(parameters.keys()):
print("\t"+str(p)+": "+ str(parameters[p]))
# generate the pipeline
full_pipeline_with_predictor = Pipeline([
#("preparation", full_pipeline),
("predictor", classifier)
])
# Execute the grid search
params = {}
for p in parameters.keys():
pipe_key = 'predictor__'+str(p)
params[pipe_key] = parameters[p]
grid_search = GridSearchCV(full_pipeline_with_predictor, params, scoring='accuracy', cv=5,
n_jobs=n_jobs, verbose=verbose)
grid_search.fit(X_train, y_train)
# Best estimator score
best_train = pct(grid_search.best_score_)
# Best estimator fitting time
start = time()
grid_search.best_estimator_.fit(X_train, y_train)
train_time = round(time() - start, 4)
# Best estimator prediction time
start = time()
best_test_accuracy = pct(grid_search.best_estimator_.score(X_test, y_test))
test_time = round(time() - start, 4)
# Generate 30 training accuracy scores with the best estimator and 30-split CV
# To calculate the best_train_accuracy use the pct() and mean() methods
#==================================================#
# Your code starts here #
#==================================================#
best_train_scores = cross_val_score(grid_search.best_estimator_, X_train, y_train, cv=cv3Splits)
best_train_accuracy = pct(best_train_scores.mean())
#==================================================#
# Your code ends here #
# Please don't add code below here #
#==================================================#
# Conduct t-test with baseline logit (control) and best estimator (experiment)
(t_stat, p_value) = stats.ttest_rel(logit_scores, best_train_scores)
# Collect the best parameters found by the grid search
print("Best Parameters:")
best_parameters = grid_search.best_estimator_.get_params()
param_dump = []
for param_name in sorted(params.keys()):
param_dump.append((param_name, best_parameters[param_name]))
print("\t"+str(param_name)+": " + str(best_parameters[param_name]))
print("****** FINISH",prefix,name," *****")
print("")
# Record the results
results.loc[i] = [prefix+name, best_train_accuracy, best_test_accuracy, round(p_value,5), train_time, test_time, json.dumps(param_dump)]
%%time
import xgboost as xgb
# This might take a while
if __name__ == "__main__":
ConductGridSearch(X_train, y_train, X_test, y_test, 0, "Best Model:", n_jobs=-1,verbose=1)
****** START Best Model: Logistic Regression *****
Parameters:
C: (10, 1, 0.1, 0.01)
penalty: ('l1', 'l2')
tol: (0.0001, 1e-05, 1e-07)
Fitting 5 folds for each of 24 candidates, totalling 120 fits
Best Parameters:
predictor__C: 0.01
predictor__penalty: l2
predictor__tol: 0.0001
****** FINISH Best Model: Logistic Regression *****
****** START Best Model: K-Nearest Neighbors *****
Parameters:
n_neighbors: (3, 5, 7, 8, 11)
p: (1, 2)
Fitting 5 folds for each of 10 candidates, totalling 50 fits
Best Parameters:
predictor__n_neighbors: 11
predictor__p: 2
****** FINISH Best Model: K-Nearest Neighbors *****
****** START Best Model: Naive Bayes *****
Parameters:
Fitting 5 folds for each of 1 candidates, totalling 5 fits
Best Parameters:
****** FINISH Best Model: Naive Bayes *****
****** START Best Model: Support Vector *****
Parameters:
C: (10, 1, 0.1, 0.01)
degree: (1, 2, 3, 4, 5)
kernel: ('rbf', 'poly')
Fitting 5 folds for each of 40 candidates, totalling 200 fits
Best Parameters:
predictor__C: 1
predictor__degree: 1
predictor__kernel: poly
****** FINISH Best Model: Support Vector *****
****** START Best Model: Stochastic GD *****
Parameters:
alpha: (0.1, 0.001, 0.0001)
tol: (0.0001, 1e-07)
Fitting 5 folds for each of 6 candidates, totalling 30 fits
Best Parameters:
predictor__alpha: 0.1
predictor__tol: 0.0001
****** FINISH Best Model: Stochastic GD *****
****** START Best Model: RandomForest *****
Parameters:
bootstrap: [False]
max_depth: [9, 15, 22]
max_features: [3, 5, 50]
min_samples_leaf: [2, 3, 5]
min_samples_split: [2, 5, 15]
n_estimators: [20, 80, 300]
Fitting 5 folds for each of 243 candidates, totalling 1215 fits
Best Parameters:
predictor__bootstrap: False
predictor__max_depth: 15
predictor__max_features: 5
predictor__min_samples_leaf: 5
predictor__min_samples_split: 15
predictor__n_estimators: 300
****** FINISH Best Model: RandomForest *****
****** START Best Model: xgb *****
Parameters:
n_estimators: [20, 80, 300]
Fitting 5 folds for each of 3 candidates, totalling 15 fits
[07:49:02] WARNING: /workspace/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.
[07:49:05] WARNING: /workspace/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.
[07:49:09] WARNING: /workspace/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.
[07:49:11] WARNING: /workspace/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.
[07:49:14] WARNING: /workspace/src/objective/regression_obj.cu:152: reg:linear is now deprecated in favor of reg:squarederror.
Best Parameters:
predictor__n_estimators: 20
****** FINISH Best Model: xgb *****
CPU times: user 59.6 s, sys: 3.6 s, total: 1min 3s
Wall time: 38min 43s
results
| ExpID | Cross fold train accuracy | Test Accuracy | p-value | Train Time(s) | Test Time(s) | Experiment description | |
|---|---|---|---|---|---|---|---|
| 0 | Baseline | 54.0 | 55.2 | --- | 2.0954 | 0.0039 | Untuned LogisticRegression |
| 1 | Best Model:Logistic Regression | 56.9 | 53.3 | 0.0198 | 0.5212 | 0.0035 | [["predictor__C", 0.01], ["predictor__penalty"... |
| 2 | Best Model:K-Nearest Neighbors | 53.0 | 56.0 | 0.51119 | 0.0036 | 0.1135 | [["predictor__n_neighbors", 11], ["predictor__... |
| 3 | Best Model:Naive Bayes | 54.0 | 55.2 | 1 | 0.0260 | 0.0129 | [] |
| 4 | Best Model:Support Vector | 56.1 | 52.9 | 0.031 | 2.5276 | 0.6033 | [["predictor__C", 1], ["predictor__degree", 1]... |
| 5 | Best Model:Stochastic GD | 52.9 | 57.9 | 0.55189 | 0.1817 | 0.0036 | [["predictor__alpha", 0.1], ["predictor__tol",... |
| 6 | Best Model:RandomForest | 54.0 | 57.1 | 1 | 1.9332 | 0.0664 | [["predictor__bootstrap", false], ["predictor_... |
| 7 | Best Model:xgb | -0.2 | -1.6 | 0.00049 | 3.4945 | 0.0102 | [["predictor__n_estimators", 20]] |
df.columns
Index(['ImageID', 'Source', 'LabelName', 'Confidence', 'XMin', 'XMax', 'YMin',
'YMax', 'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsDepiction',
'IsInside', 'XClick1X', 'XClick2X', 'XClick3X', 'XClick4X', 'XClick1Y',
'XClick2Y', 'XClick3Y', 'XClick4Y', 'ImageArray', 'Label'],
dtype='object')
Z=df.copy()
Z.drop(columns=['Label','XMin', 'XMax', 'YMin','YMax','ImageID', 'Source', 'LabelName'],inplace=True)
Z.drop(columns=["ImageArray"],inplace=True)
# np.concatenate([df['ImageArray'].iloc[0],np.array([1,2])])
Z.columns
Index(['Confidence', 'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsDepiction',
'IsInside', 'XClick1X', 'XClick2X', 'XClick3X', 'XClick4X', 'XClick1Y',
'XClick2Y', 'XClick3Y', 'XClick4Y'],
dtype='object')
columns=Z.columns
final=[]
W=[]
for i in range(0,len(df)):
arr=[]
for each in columns:
arr.append(Z[each].iloc[i])
# final.append(np.concatenate([df['ImageArray'].iloc[i],np.array(arr)]))
final.append(df['ImageArray'].iloc[i])
np.array(final).shape
(12966, 3072)
classDf=df.copy()
classDf.drop(columns=['ImageID', 'Source', 'LabelName', 'Confidence', 'ImageArray', 'Label'],inplace=True)
columns=classDf.columns
classification=[]
for i in range(0,len(df)):
arr1=[]
for each in columns:
arr1.append(classDf[each].iloc[i])
# classification.append(np.concatenate([df['ImageArray'].iloc[i],np.array(arr1)]))
classification.append(df['ImageArray'].iloc[i])
np.array(classification).shape
(12966, 3072)
np.array(final).shape
(12966, 3072)
X_train, X_test, y_train, y_test = train_test_split(np.array(final), y_bbox, test_size=0.01, random_state=27)
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.1, random_state=27)
%%time
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from time import time
#
basepipeline_LR = Pipeline([
("std_scaler", StandardScaler()),
('LinearRegression', LinearRegression())
])
start = time()
basepipeline_LR.fit(X_train, y_train)
train_time = np.round(time() - start, 4)
# might take a few minutes to train
#CPU times: user 1h 26min 40s, sys: 5min 53s, total: 1h 32min 34s
#Wall time: 17min 24s
CPU times: user 45.9 s, sys: 1.74 s, total: 47.7 s Wall time: 25.2 s
from sklearn.metrics import mean_absolute_error, r2_score
def mean_absolute_percentage_error(y_true, y_pred):
return np.mean(np.abs((y_true.ravel() - y_pred.ravel()) / y_true.ravel())) * 100
expLog_LR = pd.DataFrame(columns=["exp_name",
"Train RMSE",
"Valid RMSE",
"Test RMSE",
"Train MAE",
"Valid MAE",
"Test MAE",
"Train time"
])
print(mean_squared_error(y_train, basepipeline_LR.predict(X_train)),
mean_squared_error(y_valid, basepipeline_LR.predict(X_valid)),
mean_squared_error(y_test, basepipeline_LR.predict(X_test)))
exp_name = f"Baseline: Linear Regression Model"
expLog_LR.loc[0,:10] = [f"{exp_name}"]+list(np.round([np.sqrt(mean_squared_error(y_train, basepipeline_LR.predict(X_train))),
np.sqrt(mean_squared_error(y_valid, basepipeline_LR.predict(X_valid))),
np.sqrt(mean_squared_error(y_test, basepipeline_LR.predict(X_test))),
mean_absolute_error(y_train, basepipeline_LR.predict(X_train)),
mean_absolute_error(y_valid, basepipeline_LR.predict(X_valid)),
mean_absolute_error(y_test, basepipeline_LR.predict(X_test)),
train_time
],3))
expLog_LR
0.0071202878133270545 0.01518516265068297 0.014661098016082917
| exp_name | Train RMSE | Valid RMSE | Test RMSE | Train MAE | Valid MAE | Test MAE | Train time | |
|---|---|---|---|---|---|---|---|---|
| 0 | Baseline: Linear Regression Model | 0.084 | 0.123 | 0.121 | 0.066 | 0.096 | 0.094 | 25.227 |
%%time
from sklearn.linear_model import Lasso
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
#
estimators = [('ridge', Ridge()),
('lasso', Lasso())]
best_score = []
best_param = []
start = time()
for estimator in estimators:
params = {estimator[0]+'__alpha':[.05, .1, .5, 1]}
pipe = Pipeline([('std_scaler', StandardScaler()),
(estimator[0], estimator[1])])
gs = GridSearchCV(pipe,params,scoring='neg_mean_squared_error')
gs.fit(X_train, y_train)
best_score.append(gs.best_score_)
best_param.append(gs.best_params_)
best_idx = np.argmax(best_score)
train_time = np.round(time() - start, 4)
print('Best model is:', estimators[best_idx][0], 'with parameter', best_param[best_idx])
Best model is: lasso with parameter {'lasso__alpha': 0.05}
CPU times: user 3min 1s, sys: 21.2 s, total: 3min 22s
Wall time: 1min 58s
Best parameter Alpha:
estimator = [Ridge, Lasso][best_idx]
param = list(gs.best_params_.values())[0]
print (param)
LR_Ridge_Lasso_Pregularized_pipe = Pipeline([
('std_scalar', StandardScaler()),
('estimator', estimator(alpha=param))])
start = time()
LR_Ridge_Lasso_Pregularized_pipe.fit(X_train, y_train)
train_time = np.round(time() - start, 4)
0.05
Experiment Evaluation for Ridge/Lasso:
exp_name = f"Linear Regression(best regularization and alpha)"
expLog_LR.loc[len(expLog_LR)] = [f"{exp_name}"]+list(np.round([np.sqrt(mean_squared_error(y_train, LR_Ridge_Lasso_Pregularized_pipe.predict(X_train))),
np.sqrt(mean_squared_error(y_valid, LR_Ridge_Lasso_Pregularized_pipe.predict(X_valid))),
np.sqrt(mean_squared_error(y_test, LR_Ridge_Lasso_Pregularized_pipe.predict(X_test))),
mean_absolute_error(y_train, LR_Ridge_Lasso_Pregularized_pipe.predict(X_train)),
mean_absolute_error(y_valid, LR_Ridge_Lasso_Pregularized_pipe.predict(X_valid)),
mean_absolute_error(y_test, LR_Ridge_Lasso_Pregularized_pipe.predict(X_test)),
train_time
],3))
expLog_LR
| exp_name | Train RMSE | Valid RMSE | Test RMSE | Train MAE | Valid MAE | Test MAE | Train time | |
|---|---|---|---|---|---|---|---|---|
| 0 | Baseline: Linear Regression Model | 0.084 | 0.123 | 0.121 | 0.066 | 0.096 | 0.094 | 25.227 |
| 1 | Linear Regression(best regularization and alpha) | 0.102 | 0.103 | 0.096 | 0.082 | 0.083 | 0.079 | 1.776 |
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.multioutput import MultiOutputRegressor
Random_Forest_LR = Pipeline([
('std_scaler', StandardScaler()),
('Random_Forest', MultiOutputRegressor(RandomForestRegressor()))
])
grid_forest_fr = {
'Random_Forest__estimator__max_depth': [3, 5],
'Random_Forest__estimator__max_features': [2, 3],
'Random_Forest__estimator__min_samples_leaf': [3, 4],
'Random_Forest__estimator__min_samples_split': [2, 5],
'Random_Forest__estimator__n_estimators': [80, 100]
}
random_forest_LR = (GridSearchCV(estimator=Random_Forest_LR,
param_grid=grid_forest_fr,
cv=2,
scoring = 'neg_mean_squared_error',
n_jobs = -1))
random_forest_LR.estimator.get_params().keys()
start = time()
random_forest_LR = random_forest_LR.fit(X_train,y_train)
train_time = np.round(time() - start, 4)
random_forest_LR.best_estimator_
Pipeline(steps=[('std_scaler', StandardScaler()),
('Random_Forest',
MultiOutputRegressor(estimator=RandomForestRegressor(max_depth=5,
max_features=3,
min_samples_leaf=4,
min_samples_split=5)))])
exp_name = f"Multi output Random forest regressor"
expLog_LR.loc[len(expLog_LR)] = [f"{exp_name}"]+list(np.round([np.sqrt(mean_squared_error(y_train, random_forest_LR.predict(X_train))),
np.sqrt(mean_squared_error(y_valid, random_forest_LR.predict(X_valid))),
np.sqrt(mean_squared_error(y_test, random_forest_LR.predict(X_test))),
mean_absolute_error(y_train, random_forest_LR.predict(X_train)),
mean_absolute_error(y_valid, random_forest_LR.predict(X_valid)),
mean_absolute_error(y_test, random_forest_LR.predict(X_test)),
train_time
],3))
expLog_LR
| exp_name | Train RMSE | Valid RMSE | Test RMSE | Train MAE | Valid MAE | Test MAE | Train time | |
|---|---|---|---|---|---|---|---|---|
| 0 | Baseline: Linear Regression Model | 0.084 | 0.123 | 0.121 | 0.066 | 0.096 | 0.094 | 25.227 |
| 1 | Linear Regression(best regularization and alpha) | 0.102 | 0.103 | 0.096 | 0.082 | 0.083 | 0.079 | 1.776 |
| 2 | Multi output Random forest regressor | 0.098 | 0.101 | 0.095 | 0.079 | 0.081 | 0.077 | 132.356 |
expLog.iloc[0,4:] = list(np.round([mean_squared_error(y_train, basepipeline_LR.predict(X_train)),
mean_squared_error(y_valid, basepipeline_LR.predict(X_valid)),
mean_squared_error(y_test, basepipeline_LR.predict(X_test))],3))
expLog
| exp_name | Train Acc | Valid Acc | Test Acc | Train MSE | Valid MSE | Test MSE | |
|---|---|---|---|---|---|---|---|
| 0 | Baseline: Linear Model | 0.565 | 0.557 | 0.615 | 0.007 | 0.015 | 0.015 |
Our objective of the classification model is to predict the class that is whether the image contains a cat or a dog. We have compared across varied different classification models and different parameters which resulted in a maximum accuracy of 57 % for Stochastic gradient descent and the next best came up for random forest.We could observe that the majority of models have the test accuracy in the range of 52-57%. We will try to get better results by simplifying the images and by applying dimension and feature reduction along wth deep learning algorithms to improve speed and accuracy.
With regression models we have predicted the bounding box coordinates along with regression metrics for three different regressors. We tried to implement Baseline Linear regrssor, LR with Lasso and Rigde Regularization and Random forest regressor. We could see that LR with Lasso and Rigde Regularization has provided better metrics.
The main challenges was to work with a huge dataset and when shifted to colab we were running out of RAM.
To overcome this issue we had to decrease the standardized image size from 128x128 to 32x32.
However this transformation lead to loss of information and predicted distorted images which caused poor prediction of images
In phase 1, we have focused on the SKLearn Baseline models for Logistic Regression, SGDClassifier to classify the images into cats and dogs and Linear Regression for marking the bounding boxes around the cats and dogs inside the image.
We have also implemented the Homegrown Logistic Regression and obtained accuracy about 52.6% and also calculated CXE+MSE loss functions.
Plan to implement multi task neural networks for our next phase and try to improve the models accuracy using pytorch, CNN and efficientdet detector.
Implement a Homegrown Linear Regression model that has four target values.Extend the MSE loss function from one target to four targets (x, y, w, h).
Implement a Homegrown Logistic Regression model. Extend the loss function from CXE to CXE + MSE, i.e., make it a complex multitask loss function where the resulting model predicts the class and bounding box coordinates at the same time
Implement a Homegrown Linear Regression model that has four target values.Extend the MSE loss function from one target to four targets (x, y, w, h).
def normalize(X):
# X --> Input.
# m-> number of training examples
# n-> number of features
m, n = X.shape
# Normalizing all the n features of X.
for i in range(n):
X = (X - X.mean(axis=0))/X.std(axis=0)
# scale data
np.random.seed(42)
if np.max(X) > 4.:
X = X.astype(np.float32) / 255.
return X
final = normalize(np.array(final))
# classification = normalize(np.array(classification))
X_train_r, X_test_r, y_train_r, y_test_r = train_test_split(np.array(final), y_bbox, test_size=0.01, random_state=27)
X_train_r, X_valid_r, y_train_r, y_valid_r = train_test_split(X_train_r, y_train_r, test_size=0.1, random_state=27)
X_train_c, X_test_c, y_train_c, y_test_c = train_test_split(np.array(final), y_label, test_size=0.01, random_state=27)
X_train_c, X_valid_c, y_train_c, y_valid_c = train_test_split(X_train_c, y_train_c, test_size=0.1, random_state=27)
# scale data for reg
np.random.seed(42)
if np.max(X_train_r) > 4.:
X_train_r = X_train_r.astype(np.float32) / 255.
if np.max(X_valid_r) > 4.:
X_valid_r = X_valid_r.astype(np.float32) / 255.
y_train_r=y_train_r.astype(int)
y_valid_r=y_valid_r.astype(int)
# scale data for classification
np.random.seed(42)
if np.max(X_train_c) > 4.:
X_train_c = X_train_c.astype(np.float32) / 255.
if np.max(X_valid_c) > 4.:
X_valid_c = X_valid_c.astype(np.float32) / 255.
y_train_c=y_train_c.astype(int)
y_valid_c=y_valid_c.astype(int)
import warnings
warnings.filterwarnings('ignore')
class LinearRegressionHomegrown(object):
def __init__(self):
"""
Constructor for the homgrown Linear Regression
Args:
None
Return:
None
"""
self.coef_r = None # weight vector
self.intercept_r = None # bias term
self._thetaReg = None # augmented weight vector, i.e., bias + weights
# this allows to treat all decision variables homogeneously
self.history = {"MSE_train": [],
"Reg_train_MSE":[],
"val_MSE":[],
"Reg_val_MSE":[]}
def _gradReg(self, X, y):
# number of training examples
n = X.shape[0]
# get scores for each class and example
# 2D matrix
scores = self._predict_raw(X)
gradient = np.dot(X.T, scores) / n
return gradient
def _gd(self, X_r, y_r, max_iter, alpha, X_val_r, y_val_r):
"""
Runs Full GD and logs error, weigths, gradient at every step
Args:
X(ndarray): train objects
y(ndarray): answers for train objects
max_iter(int): number of weight updates
alpha(floar): step size in direction of gradient
Return:
None
"""
for i in range(max_iter):
metrics = self.score(X_r, y_r)
print("Epoch: ",i+1,"- ", metrics)
self.history["Reg_train_MSE"].append(metrics["Reg_MSE"])
if X_val_r is not None:
metrics_val = self.score(X_val_r, y_val_r)
self.history["Reg_val_MSE"].append(metrics_val["Reg_MSE"])
# calculate gradient for regressor
grad_reg = self._gradReg(X_r, y_r)
# do gradient step
self._thetaReg -= alpha * grad_reg
def fit(self, X_r,y_r, max_iter=1000, alpha=0.05, val_data_r=None):
"""
Public API to fit Logistic regression model
Args:
X(ndarray): train objects
y(ndarray): answers for train objects
max_iter(int): number of weight updates
alpha(floar): step size in direction of gradient
Return:
None
"""
# Augment the data with the bias term.
# So we can treat the the input variables and the bias term homogeneously
# from a vectorization perspective
X_r = np.c_[np.ones(X_r.shape[0]), X_r]
if val_data_r is not None:
X_val_r, y_val_r = val_data_r
X_val_r = np.c_[np.ones(X_val_r.shape[0]), X_val_r]
else:
X_val_r = None
y_val_r = None
# initialize if the first step
if self._thetaReg is None:
self._thetaReg = np.random.rand(X_r.shape[1], 4)
# do full gradient descent
self._gd(X_r, y_r, max_iter, alpha, X_val_r, y_val_r)
# get final weigths and bias
self.intercept_r = self._thetaReg[0]
self.coef_r = self._thetaReg[1:]
def score(self, X_r, y_r):
# number of training samples
n1 = X_r.shape[0]
# get scores
scores_r = self._predict_raw(X_r)
pred_r=scores_r
# exp=2.73
# probs = 1.0/(1 + exp**(-scores_c))
# # print("--------probs-------",probs)
# # ind=np.argmax(scores_c,axis=1)
# pred_c=[]
# for i,ind in enumerate(np.argmax(scores_c,axis=1)):
# pred_c.append(int(probs[i][ind]))
# acc = accuracy_score(y_c, np.array(pred_c))
# pred_c=np.array(pred_c)
# # trasnform scores to probabilities
# exp_scores = np.exp(-scores_c)
# probs1 = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
# # print('probs1',probs1,"\nexp_scores",exp_scores)
# # logloss per each example
# corect_logprobs = (probs[range(n2),y_c])
# # print("\ncorrect log prob",corect_logprobs)
# # print(corect_logprobs)
# # total mean logloss
# data_loss = np.sum(corect_logprobs) / n2
# loss=0
# for each in corect_logprobs:
# if (each!=1 and each!=0):
# loss +=y_c * (-np.log (each)) + (1-y_c) * (-np.log (1-each))
metrics = {"Reg_MSE": np.round(mean_squared_error(y_r, pred_r),decimals=10)}
# print(metrics)
return metrics
def _predict_raw(self, X):
"""
Computes scores for each class and each object in X
Args:
X(ndarray): objects
Return:
scores(ndarray): scores for each class and object
"""
# check whether X has appended bias feature or not
if X.shape[1] == len(self._thetaReg):
scores = np.dot(X, self._thetaReg)
else:
scores = np.dot(X, self.coef_r) + self.intercept_r
return scores
def predict(self, X):
"""
Predicts class for each object in X
Args:
X(ndarray): objects
Return:
pred(ndarray): class for each object
"""
# get scores for each class
scores = self._predict_raw(X)
# choose class with maximum score
pred = np.argmax(scores, axis=1)
return pred
model_linear_homegrown = LinearRegressionHomegrown()
#model_lr_homegrown.fit(X_train_r, y_train_r, max_iter=100, alpha=0.1,val_data=[X_valid_r,y_valid_r])
model_linear_homegrown.fit(X_train_r, y_train_r, max_iter=1000, alpha=0.00005,val_data_r=[X_valid_r,y_valid_r])
Epoch: 1 - {'Reg_MSE': 497311.1850962523}
Epoch: 2 - {'Reg_MSE': 464285.7627598468}
Epoch: 3 - {'Reg_MSE': 433482.930881073}
Epoch: 4 - {'Reg_MSE': 404752.3695283258}
Epoch: 5 - {'Reg_MSE': 377953.9445772342}
Epoch: 6 - {'Reg_MSE': 352957.0170035906}
Epoch: 7 - {'Reg_MSE': 329639.799027212}
Epoch: 8 - {'Reg_MSE': 307888.7539284734}
Epoch: 9 - {'Reg_MSE': 287598.0365748122}
Epoch: 10 - {'Reg_MSE': 268668.9718955551}
Epoch: 11 - {'Reg_MSE': 251009.5687307376}
Epoch: 12 - {'Reg_MSE': 234534.0666542402}
Epoch: 13 - {'Reg_MSE': 219162.5135343635}
Epoch: 14 - {'Reg_MSE': 204820.3717467155}
Epoch: 15 - {'Reg_MSE': 191438.1510957394}
Epoch: 16 - {'Reg_MSE': 178951.0666330772}
Epoch: 17 - {'Reg_MSE': 167298.719683862}
Epoch: 18 - {'Reg_MSE': 156424.8005066305}
Epoch: 19 - {'Reg_MSE': 146276.811119326}
Epoch: 20 - {'Reg_MSE': 136805.806923441}
Epoch: 21 - {'Reg_MSE': 127966.155851132}
Epoch: 22 - {'Reg_MSE': 119715.3138466571}
Epoch: 23 - {'Reg_MSE': 112013.6155741232}
Epoch: 24 - {'Reg_MSE': 104824.0793186892}
Epoch: 25 - {'Reg_MSE': 98112.2251184526}
Epoch: 26 - {'Reg_MSE': 91845.9052295512}
Epoch: 27 - {'Reg_MSE': 85995.146087894}
Epoch: 28 - {'Reg_MSE': 80532.0009876945}
Epoch: 29 - {'Reg_MSE': 75430.4127498786}
Epoch: 30 - {'Reg_MSE': 70666.0857027558}
Epoch: 31 - {'Reg_MSE': 66216.3663432992}
Epoch: 32 - {'Reg_MSE': 62060.1320902503}
Epoch: 33 - {'Reg_MSE': 58177.687580184}
Epoch: 34 - {'Reg_MSE': 54550.6679949198}
Epoch: 35 - {'Reg_MSE': 51161.9489433616}
Epoch: 36 - {'Reg_MSE': 47995.5624532058}
Epoch: 37 - {'Reg_MSE': 45036.6186581119}
Epoch: 38 - {'Reg_MSE': 42271.2327940457}
Epoch: 39 - {'Reg_MSE': 39686.4571447047}
Epoch: 40 - {'Reg_MSE': 37270.2176003637}
Epoch: 41 - {'Reg_MSE': 35011.2545172524}
Epoch: 42 - {'Reg_MSE': 32899.0675857942}
Epoch: 43 - {'Reg_MSE': 30923.8644358277}
Epoch: 44 - {'Reg_MSE': 29076.5127253697}
Epoch: 45 - {'Reg_MSE': 27348.4954766745}
Epoch: 46 - {'Reg_MSE': 25731.8694393665}
Epoch: 47 - {'Reg_MSE': 24219.2262753634}
Epoch: 48 - {'Reg_MSE': 22803.6563742302}
Epoch: 49 - {'Reg_MSE': 21478.7151205874}
Epoch: 50 - {'Reg_MSE': 20238.3914472932}
Epoch: 51 - {'Reg_MSE': 19077.0785194009}
Epoch: 52 - {'Reg_MSE': 17989.5464044028}
Epoch: 53 - {'Reg_MSE': 16970.9165940795}
Epoch: 54 - {'Reg_MSE': 16016.6382523988}
Epoch: 55 - {'Reg_MSE': 15122.4660724341}
Epoch: 56 - {'Reg_MSE': 14284.439633205}
Epoch: 57 - {'Reg_MSE': 13498.8641547422}
Epoch: 58 - {'Reg_MSE': 12762.2925565804}
Epoch: 59 - {'Reg_MSE': 12071.5087313079}
Epoch: 60 - {'Reg_MSE': 11423.5119508002}
Epoch: 61 - {'Reg_MSE': 10815.5023283461}
Epoch: 62 - {'Reg_MSE': 10244.8672650893}
Epoch: 63 - {'Reg_MSE': 9709.1688140571}
Epoch: 64 - {'Reg_MSE': 9206.1318995787}
Epoch: 65 - {'Reg_MSE': 8733.6333341091}
Epoch: 66 - {'Reg_MSE': 8289.6915784113}
Epoch: 67 - {'Reg_MSE': 7872.4571947111}
Epoch: 68 - {'Reg_MSE': 7480.2039458595}
Epoch: 69 - {'Reg_MSE': 7111.3204967188}
Epoch: 70 - {'Reg_MSE': 6764.3026769621}
Epoch: 71 - {'Reg_MSE': 6437.7462672382}
Epoch: 72 - {'Reg_MSE': 6130.3402732403}
Epoch: 73 - {'Reg_MSE': 5840.8606546147}
Epoch: 74 - {'Reg_MSE': 5568.1644778937}
Epoch: 75 - {'Reg_MSE': 5311.1844647234}
Epoch: 76 - {'Reg_MSE': 5068.9239086045}
Epoch: 77 - {'Reg_MSE': 4840.451935184}
Epoch: 78 - {'Reg_MSE': 4624.8990828222}
Epoch: 79 - {'Reg_MSE': 4421.4531817444}
Epoch: 80 - {'Reg_MSE': 4229.3555115514}
Epoch: 81 - {'Reg_MSE': 4047.8972182377}
Epoch: 82 - {'Reg_MSE': 3876.4159731431}
Epoch: 83 - {'Reg_MSE': 3714.2928574554}
Epoch: 84 - {'Reg_MSE': 3560.9494569911}
Epoch: 85 - {'Reg_MSE': 3415.8451530187}
Epoch: 86 - {'Reg_MSE': 3278.4745958519}
Epoch: 87 - {'Reg_MSE': 3148.3653488409}
Epoch: 88 - {'Reg_MSE': 3025.0756912286}
Epoch: 89 - {'Reg_MSE': 2908.1925691193}
Epoch: 90 - {'Reg_MSE': 2797.3296845385}
Epoch: 91 - {'Reg_MSE': 2692.1257132373}
Epoch: 92 - {'Reg_MSE': 2592.2426425345}
Epoch: 93 - {'Reg_MSE': 2497.3642210725}
Epoch: 94 - {'Reg_MSE': 2407.1945129205}
Epoch: 95 - {'Reg_MSE': 2321.4565489656}
Epoch: 96 - {'Reg_MSE': 2239.8910690139}
Epoch: 97 - {'Reg_MSE': 2162.2553484692}
Epoch: 98 - {'Reg_MSE': 2088.3221038703}
Epoch: 99 - {'Reg_MSE': 2017.8784719582}
Epoch: 100 - {'Reg_MSE': 1950.7250573026}
Epoch: 101 - {'Reg_MSE': 1886.6750438571}
Epoch: 102 - {'Reg_MSE': 1825.5533661221}
Epoch: 103 - {'Reg_MSE': 1767.1959358912}
Epoch: 104 - {'Reg_MSE': 1711.4489208261}
Epoch: 105 - {'Reg_MSE': 1658.1680713613}
Epoch: 106 - {'Reg_MSE': 1607.2180926762}
Epoch: 107 - {'Reg_MSE': 1558.4720586926}
Epoch: 108 - {'Reg_MSE': 1511.8108652626}
Epoch: 109 - {'Reg_MSE': 1467.1227199027}
Epoch: 110 - {'Reg_MSE': 1424.3026656094}
Epoch: 111 - {'Reg_MSE': 1383.2521364592}
Epoch: 112 - {'Reg_MSE': 1343.8785428491}
Epoch: 113 - {'Reg_MSE': 1306.0948843821}
Epoch: 114 - {'Reg_MSE': 1269.8193885335}
Epoch: 115 - {'Reg_MSE': 1234.9751733629}
Epoch: 116 - {'Reg_MSE': 1201.4899326535}
Epoch: 117 - {'Reg_MSE': 1169.2956419672}
Epoch: 118 - {'Reg_MSE': 1138.3282842109}
Epoch: 119 - {'Reg_MSE': 1108.5275934}
Epoch: 120 - {'Reg_MSE': 1079.8368153967}
Epoch: 121 - {'Reg_MSE': 1052.2024844817}
Epoch: 122 - {'Reg_MSE': 1025.5742146958}
Epoch: 123 - {'Reg_MSE': 999.9045049607}
Epoch: 124 - {'Reg_MSE': 975.1485570525}
Epoch: 125 - {'Reg_MSE': 951.264105567}
Epoch: 126 - {'Reg_MSE': 928.2112590724}
Epoch: 127 - {'Reg_MSE': 905.9523516989}
Epoch: 128 - {'Reg_MSE': 884.4518044679}
Epoch: 129 - {'Reg_MSE': 863.6759957066}
Epoch: 130 - {'Reg_MSE': 843.5931399418}
Epoch: 131 - {'Reg_MSE': 824.1731747051}
Epoch: 132 - {'Reg_MSE': 805.3876547212}
Epoch: 133 - {'Reg_MSE': 787.2096529863}
Epoch: 134 - {'Reg_MSE': 769.6136682758}
Epoch: 135 - {'Reg_MSE': 752.5755386546}
Epoch: 136 - {'Reg_MSE': 736.0723605878}
Epoch: 137 - {'Reg_MSE': 720.0824132802}
Epoch: 138 - {'Reg_MSE': 704.5850878964}
Epoch: 139 - {'Reg_MSE': 689.5608213371}
Epoch: 140 - {'Reg_MSE': 674.9910342686}
Epoch: 141 - {'Reg_MSE': 660.8580731237}
Epoch: 142 - {'Reg_MSE': 647.1451558111}
Epoch: 143 - {'Reg_MSE': 633.8363208863}
Epoch: 144 - {'Reg_MSE': 620.9163799564}
Epoch: 145 - {'Reg_MSE': 608.3708731046}
Epoch: 146 - {'Reg_MSE': 596.1860271344}
Epoch: 147 - {'Reg_MSE': 584.348716449}
Epoch: 148 - {'Reg_MSE': 572.8464263909}
Epoch: 149 - {'Reg_MSE': 561.6672188802}
Epoch: 150 - {'Reg_MSE': 550.799700202}
Epoch: 151 - {'Reg_MSE': 540.2329907994}
Epoch: 152 - {'Reg_MSE': 529.9566969431}
Epoch: 153 - {'Reg_MSE': 519.9608841535}
Epoch: 154 - {'Reg_MSE': 510.2360522612}
Epoch: 155 - {'Reg_MSE': 500.7731119988}
Epoch: 156 - {'Reg_MSE': 491.5633630256}
Epoch: 157 - {'Reg_MSE': 482.5984732894}
Epoch: 158 - {'Reg_MSE': 473.8704596422}
Epoch: 159 - {'Reg_MSE': 465.3716696252}
Epoch: 160 - {'Reg_MSE': 457.0947643503}
Epoch: 161 - {'Reg_MSE': 449.0327024054}
Epoch: 162 - {'Reg_MSE': 441.1787247186}
Epoch: 163 - {'Reg_MSE': 433.52634032}
Epoch: 164 - {'Reg_MSE': 426.0693129427}
Epoch: 165 - {'Reg_MSE': 418.8016484097}
Epoch: 166 - {'Reg_MSE': 411.7175827571}
Epoch: 167 - {'Reg_MSE': 404.8115710464}
Epoch: 168 - {'Reg_MSE': 398.0782768216}
Epoch: 169 - {'Reg_MSE': 391.5125621718}
Epoch: 170 - {'Reg_MSE': 385.1094783593}
Epoch: 171 - {'Reg_MSE': 378.8642569792}
Epoch: 172 - {'Reg_MSE': 372.7723016165}
Epoch: 173 - {'Reg_MSE': 366.829179969}
Epoch: 174 - {'Reg_MSE': 361.0306164085}
Epoch: 175 - {'Reg_MSE': 355.3724849517}
Epoch: 176 - {'Reg_MSE': 349.8508026162}
Epoch: 177 - {'Reg_MSE': 344.4617231373}
Epoch: 178 - {'Reg_MSE': 339.2015310245}
Epoch: 179 - {'Reg_MSE': 334.0666359356}
Epoch: 180 - {'Reg_MSE': 329.0535673497}
Epoch: 181 - {'Reg_MSE': 324.1589695218}
Epoch: 182 - {'Reg_MSE': 319.3795966996}
Epoch: 183 - {'Reg_MSE': 314.71230859}
Epoch: 184 - {'Reg_MSE': 310.1540660576}
Epoch: 185 - {'Reg_MSE': 305.7019270425}
Epoch: 186 - {'Reg_MSE': 301.3530426847}
Epoch: 187 - {'Reg_MSE': 297.1046536426}
Epoch: 188 - {'Reg_MSE': 292.9540865934}
Epoch: 189 - {'Reg_MSE': 288.8987509066}
Epoch: 190 - {'Reg_MSE': 284.9361354789}
Epoch: 191 - {'Reg_MSE': 281.0638057217}
Epoch: 192 - {'Reg_MSE': 277.2794006931}
Epoch: 193 - {'Reg_MSE': 273.5806303651}
Epoch: 194 - {'Reg_MSE': 269.9652730193}
Epoch: 195 - {'Reg_MSE': 266.4311727627}
Epoch: 196 - {'Reg_MSE': 262.9762371588}
Epoch: 197 - {'Reg_MSE': 259.5984349648}
Epoch: 198 - {'Reg_MSE': 256.2957939725}
Epoch: 199 - {'Reg_MSE': 253.0663989439}
Epoch: 200 - {'Reg_MSE': 249.9083896383}
Epoch: 201 - {'Reg_MSE': 246.8199589257}
Epoch: 202 - {'Reg_MSE': 243.7993509817}
Epoch: 203 - {'Reg_MSE': 240.8448595588}
Epoch: 204 - {'Reg_MSE': 237.954826332}
Epoch: 205 - {'Reg_MSE': 235.1276393124}
Epoch: 206 - {'Reg_MSE': 232.3617313272}
Epoch: 207 - {'Reg_MSE': 229.6555785619}
Epoch: 208 - {'Reg_MSE': 227.0076991615}
Epoch: 209 - {'Reg_MSE': 224.4166518877}
Epoch: 210 - {'Reg_MSE': 221.8810348296}
Epoch: 211 - {'Reg_MSE': 219.399484165}
Epoch: 212 - {'Reg_MSE': 216.9706729697}
Epoch: 213 - {'Reg_MSE': 214.5933100729}
Epoch: 214 - {'Reg_MSE': 212.2661389561}
Epoch: 215 - {'Reg_MSE': 209.9879366931}
Epoch: 216 - {'Reg_MSE': 207.7575129303}
Epoch: 217 - {'Reg_MSE': 205.5737089041}
Epoch: 218 - {'Reg_MSE': 203.4353964944}
Epoch: 219 - {'Reg_MSE': 201.3414773124}
Epoch: 220 - {'Reg_MSE': 199.2908818213}
Epoch: 221 - {'Reg_MSE': 197.2825684879}
Epoch: 222 - {'Reg_MSE': 195.3155229641}
Epoch: 223 - {'Reg_MSE': 193.3887572974}
Epoch: 224 - {'Reg_MSE': 191.501309168}
Epoch: 225 - {'Reg_MSE': 189.6522411526}
Epoch: 226 - {'Reg_MSE': 187.8406400125}
Epoch: 227 - {'Reg_MSE': 186.0656160064}
Epoch: 228 - {'Reg_MSE': 184.3263022258}
Epoch: 229 - {'Reg_MSE': 182.621853952}
Epoch: 230 - {'Reg_MSE': 180.9514480347}
Epoch: 231 - {'Reg_MSE': 179.3142822905}
Epoch: 232 - {'Reg_MSE': 177.7095749209}
Epoch: 233 - {'Reg_MSE': 176.1365639489}
Epoch: 234 - {'Reg_MSE': 174.5945066736}
Epoch: 235 - {'Reg_MSE': 173.0826791417}
Epoch: 236 - {'Reg_MSE': 171.6003756358}
Epoch: 237 - {'Reg_MSE': 170.1469081785}
Epoch: 238 - {'Reg_MSE': 168.7216060517}
Epoch: 239 - {'Reg_MSE': 167.3238153309}
Epoch: 240 - {'Reg_MSE': 165.9528984332}
Epoch: 241 - {'Reg_MSE': 164.6082336793}
Epoch: 242 - {'Reg_MSE': 163.2892148689}
Epoch: 243 - {'Reg_MSE': 161.995250868}
Epoch: 244 - {'Reg_MSE': 160.7257652089}
Epoch: 245 - {'Reg_MSE': 159.4801957026}
Epoch: 246 - {'Reg_MSE': 158.2579940612}
Epoch: 247 - {'Reg_MSE': 157.0586255323}
Epoch: 248 - {'Reg_MSE': 155.881568544}
Epoch: 249 - {'Reg_MSE': 154.7263143596}
Epoch: 250 - {'Reg_MSE': 153.5923667425}
Epoch: 251 - {'Reg_MSE': 152.4792416309}
Epoch: 252 - {'Reg_MSE': 151.3864668214}
Epoch: 253 - {'Reg_MSE': 150.3135816614}
Epoch: 254 - {'Reg_MSE': 149.2601367507}
Epoch: 255 - {'Reg_MSE': 148.2256936512}
Epoch: 256 - {'Reg_MSE': 147.2098246045}
Epoch: 257 - {'Reg_MSE': 146.2121122576}
Epoch: 258 - {'Reg_MSE': 145.2321493961}
Epoch: 259 - {'Reg_MSE': 144.2695386847}
Epoch: 260 - {'Reg_MSE': 143.3238924151}
Epoch: 261 - {'Reg_MSE': 142.3948322604}
Epoch: 262 - {'Reg_MSE': 141.4819890364}
Epoch: 263 - {'Reg_MSE': 140.5850024695}
Epoch: 264 - {'Reg_MSE': 139.7035209707}
Epoch: 265 - {'Reg_MSE': 138.8372014155}
Epoch: 266 - {'Reg_MSE': 137.9857089305}
Epoch: 267 - {'Reg_MSE': 137.1487166848}
Epoch: 268 - {'Reg_MSE': 136.3259056873}
Epoch: 269 - {'Reg_MSE': 135.51696459}
Epoch: 270 - {'Reg_MSE': 134.7215894954}
Epoch: 271 - {'Reg_MSE': 133.9394837701}
Epoch: 272 - {'Reg_MSE': 133.1703578627}
Epoch: 273 - {'Reg_MSE': 132.4139291263}
Epoch: 274 - {'Reg_MSE': 131.6699216468}
Epoch: 275 - {'Reg_MSE': 130.9380660742}
Epoch: 276 - {'Reg_MSE': 130.2180994593}
Epoch: 277 - {'Reg_MSE': 129.5097650947}
Epoch: 278 - {'Reg_MSE': 128.8128123593}
Epoch: 279 - {'Reg_MSE': 128.1269965673}
Epoch: 280 - {'Reg_MSE': 127.4520788212}
Epoch: 281 - {'Reg_MSE': 126.7878258685}
Epoch: 282 - {'Reg_MSE': 126.1340099615}
Epoch: 283 - {'Reg_MSE': 125.490408722}
Epoch: 284 - {'Reg_MSE': 124.8568050083}
Epoch: 285 - {'Reg_MSE': 124.2329867865}
Epoch: 286 - {'Reg_MSE': 123.618747004}
Epoch: 287 - {'Reg_MSE': 123.0138834678}
Epoch: 288 - {'Reg_MSE': 122.4181987242}
Epoch: 289 - {'Reg_MSE': 121.8314999432}
Epoch: 290 - {'Reg_MSE': 121.2535988046}
Epoch: 291 - {'Reg_MSE': 120.6843113876}
Epoch: 292 - {'Reg_MSE': 120.1234580629}
Epoch: 293 - {'Reg_MSE': 119.5708633884}
Epoch: 294 - {'Reg_MSE': 119.0263560059}
Epoch: 295 - {'Reg_MSE': 118.4897685423}
Epoch: 296 - {'Reg_MSE': 117.9609375117}
Epoch: 297 - {'Reg_MSE': 117.4397032211}
Epoch: 298 - {'Reg_MSE': 116.925909678}
Epoch: 299 - {'Reg_MSE': 116.4194045003}
Epoch: 300 - {'Reg_MSE': 115.9200388284}
Epoch: 301 - {'Reg_MSE': 115.42766724}
Epoch: 302 - {'Reg_MSE': 114.9421476664}
Epoch: 303 - {'Reg_MSE': 114.4633413114}
Epoch: 304 - {'Reg_MSE': 113.9911125721}
Epoch: 305 - {'Reg_MSE': 113.5253289612}
Epoch: 306 - {'Reg_MSE': 113.0658610323}
Epoch: 307 - {'Reg_MSE': 112.6125823059}
Epoch: 308 - {'Reg_MSE': 112.1653691979}
Epoch: 309 - {'Reg_MSE': 111.72410095}
Epoch: 310 - {'Reg_MSE': 111.2886595615}
Epoch: 311 - {'Reg_MSE': 110.8589297227}
Epoch: 312 - {'Reg_MSE': 110.4347987504}
Epoch: 313 - {'Reg_MSE': 110.0161565245}
Epoch: 314 - {'Reg_MSE': 109.6028954268}
Epoch: 315 - {'Reg_MSE': 109.1949102807}
Epoch: 316 - {'Reg_MSE': 108.7920982924}
Epoch: 317 - {'Reg_MSE': 108.3943589945}
Epoch: 318 - {'Reg_MSE': 108.0015941896}
Epoch: 319 - {'Reg_MSE': 107.6137078965}
Epoch: 320 - {'Reg_MSE': 107.2306062968}
Epoch: 321 - {'Reg_MSE': 106.8521976835}
Epoch: 322 - {'Reg_MSE': 106.4783924106}
Epoch: 323 - {'Reg_MSE': 106.1091028438}
Epoch: 324 - {'Reg_MSE': 105.7442433128}
Epoch: 325 - {'Reg_MSE': 105.3837300643}
Epoch: 326 - {'Reg_MSE': 105.0274812165}
Epoch: 327 - {'Reg_MSE': 104.6754167146}
Epoch: 328 - {'Reg_MSE': 104.3274582875}
Epoch: 329 - {'Reg_MSE': 103.9835294052}
Epoch: 330 - {'Reg_MSE': 103.6435552378}
Epoch: 331 - {'Reg_MSE': 103.3074626152}
Epoch: 332 - {'Reg_MSE': 102.9751799876}
Epoch: 333 - {'Reg_MSE': 102.6466373872}
Epoch: 334 - {'Reg_MSE': 102.3217663909}
Epoch: 335 - {'Reg_MSE': 102.0005000841}
Epoch: 336 - {'Reg_MSE': 101.6827730243}
Epoch: 337 - {'Reg_MSE': 101.3685212074}
Epoch: 338 - {'Reg_MSE': 101.0576820329}
Epoch: 339 - {'Reg_MSE': 100.7501942716}
Epoch: 340 - {'Reg_MSE': 100.4459980327}
Epoch: 341 - {'Reg_MSE': 100.1450347328}
Epoch: 342 - {'Reg_MSE': 99.8472470651}
Epoch: 343 - {'Reg_MSE': 99.5525789693}
Epoch: 344 - {'Reg_MSE': 99.2609756026}
Epoch: 345 - {'Reg_MSE': 98.9723833111}
Epoch: 346 - {'Reg_MSE': 98.6867496017}
Epoch: 347 - {'Reg_MSE': 98.4040231154}
Epoch: 348 - {'Reg_MSE': 98.1241536007}
Epoch: 349 - {'Reg_MSE': 97.8470918875}
Epoch: 350 - {'Reg_MSE': 97.572789862}
Epoch: 351 - {'Reg_MSE': 97.3012004422}
Epoch: 352 - {'Reg_MSE': 97.032277554}
Epoch: 353 - {'Reg_MSE': 96.7659761073}
Epoch: 354 - {'Reg_MSE': 96.5022519735}
Epoch: 355 - {'Reg_MSE': 96.2410619631}
Epoch: 356 - {'Reg_MSE': 95.9823638039}
Epoch: 357 - {'Reg_MSE': 95.7261161197}
Epoch: 358 - {'Reg_MSE': 95.4722784096}
Epoch: 359 - {'Reg_MSE': 95.2208110279}
Epoch: 360 - {'Reg_MSE': 94.9716751638}
Epoch: 361 - {'Reg_MSE': 94.7248328229}
Epoch: 362 - {'Reg_MSE': 94.4802468078}
Epoch: 363 - {'Reg_MSE': 94.2378807}
Epoch: 364 - {'Reg_MSE': 93.9976988416}
Epoch: 365 - {'Reg_MSE': 93.7596663184}
Epoch: 366 - {'Reg_MSE': 93.5237489424}
Epoch: 367 - {'Reg_MSE': 93.2899132352}
Epoch: 368 - {'Reg_MSE': 93.0581264116}
Epoch: 369 - {'Reg_MSE': 92.8283563641}
Epoch: 370 - {'Reg_MSE': 92.600571647}
Epoch: 371 - {'Reg_MSE': 92.3747414614}
Epoch: 372 - {'Reg_MSE': 92.1508356407}
Epoch: 373 - {'Reg_MSE': 91.9288246356}
Epoch: 374 - {'Reg_MSE': 91.7086795005}
Epoch: 375 - {'Reg_MSE': 91.4903718797}
Epoch: 376 - {'Reg_MSE': 91.2738739936}
Epoch: 377 - {'Reg_MSE': 91.0591586263}
Epoch: 378 - {'Reg_MSE': 90.8461991119}
Epoch: 379 - {'Reg_MSE': 90.634969323}
Epoch: 380 - {'Reg_MSE': 90.4254436578}
Epoch: 381 - {'Reg_MSE': 90.2175970285}
Epoch: 382 - {'Reg_MSE': 90.0114048496}
Epoch: 383 - {'Reg_MSE': 89.8068430266}
Epoch: 384 - {'Reg_MSE': 89.6038879449}
Epoch: 385 - {'Reg_MSE': 89.4025164591}
Epoch: 386 - {'Reg_MSE': 89.2027058823}
Epoch: 387 - {'Reg_MSE': 89.0044339758}
Epoch: 388 - {'Reg_MSE': 88.8076789391}
Epoch: 389 - {'Reg_MSE': 88.6124194003}
Epoch: 390 - {'Reg_MSE': 88.418634406}
Epoch: 391 - {'Reg_MSE': 88.2263034122}
Epoch: 392 - {'Reg_MSE': 88.0354062755}
Epoch: 393 - {'Reg_MSE': 87.8459232435}
Epoch: 394 - {'Reg_MSE': 87.6578349467}
Epoch: 395 - {'Reg_MSE': 87.4711223893}
Epoch: 396 - {'Reg_MSE': 87.2857669418}
Epoch: 397 - {'Reg_MSE': 87.1017503321}
Epoch: 398 - {'Reg_MSE': 86.9190546378}
Epoch: 399 - {'Reg_MSE': 86.7376622786}
Epoch: 400 - {'Reg_MSE': 86.5575560086}
Epoch: 401 - {'Reg_MSE': 86.3787189091}
Epoch: 402 - {'Reg_MSE': 86.201134381}
Epoch: 403 - {'Reg_MSE': 86.0247861382}
Epoch: 404 - {'Reg_MSE': 85.8496582004}
Epoch: 405 - {'Reg_MSE': 85.6757348867}
Epoch: 406 - {'Reg_MSE': 85.5030008085}
Epoch: 407 - {'Reg_MSE': 85.3314408638}
Epoch: 408 - {'Reg_MSE': 85.1610402304}
Epoch: 409 - {'Reg_MSE': 84.99178436}
Epoch: 410 - {'Reg_MSE': 84.823658972}
Epoch: 411 - {'Reg_MSE': 84.6566500482}
Epoch: 412 - {'Reg_MSE': 84.4907438266}
Epoch: 413 - {'Reg_MSE': 84.3259267959}
Epoch: 414 - {'Reg_MSE': 84.1621856904}
Epoch: 415 - {'Reg_MSE': 83.9995074841}
Epoch: 416 - {'Reg_MSE': 83.8378793862}
Epoch: 417 - {'Reg_MSE': 83.6772888356}
Epoch: 418 - {'Reg_MSE': 83.5177234959}
Epoch: 419 - {'Reg_MSE': 83.3591712509}
Epoch: 420 - {'Reg_MSE': 83.2016201995}
Epoch: 421 - {'Reg_MSE': 83.0450586516}
Epoch: 422 - {'Reg_MSE': 82.8894751229}
Epoch: 423 - {'Reg_MSE': 82.734858331}
Epoch: 424 - {'Reg_MSE': 82.581197191}
Epoch: 425 - {'Reg_MSE': 82.4284808113}
Epoch: 426 - {'Reg_MSE': 82.2766984892}
Epoch: 427 - {'Reg_MSE': 82.1258397074}
Epoch: 428 - {'Reg_MSE': 81.9758941297}
Epoch: 429 - {'Reg_MSE': 81.8268515972}
Epoch: 430 - {'Reg_MSE': 81.6787021248}
Epoch: 431 - {'Reg_MSE': 81.5314358972}
Epoch: 432 - {'Reg_MSE': 81.3850432656}
Epoch: 433 - {'Reg_MSE': 81.2395147442}
Epoch: 434 - {'Reg_MSE': 81.0948410064}
Epoch: 435 - {'Reg_MSE': 80.951012882}
Epoch: 436 - {'Reg_MSE': 80.8080213537}
Epoch: 437 - {'Reg_MSE': 80.6658575538}
Epoch: 438 - {'Reg_MSE': 80.5245127612}
Epoch: 439 - {'Reg_MSE': 80.3839783985}
Epoch: 440 - {'Reg_MSE': 80.2442460287}
Epoch: 441 - {'Reg_MSE': 80.1053073526}
Epoch: 442 - {'Reg_MSE': 79.9671542057}
Epoch: 443 - {'Reg_MSE': 79.8297785559}
Epoch: 444 - {'Reg_MSE': 79.6931725}
Epoch: 445 - {'Reg_MSE': 79.5573282618}
Epoch: 446 - {'Reg_MSE': 79.4222381891}
Epoch: 447 - {'Reg_MSE': 79.2878947512}
Epoch: 448 - {'Reg_MSE': 79.1542905367}
Epoch: 449 - {'Reg_MSE': 79.0214182508}
Epoch: 450 - {'Reg_MSE': 78.8892707128}
Epoch: 451 - {'Reg_MSE': 78.7578408542}
Epoch: 452 - {'Reg_MSE': 78.6271217163}
Epoch: 453 - {'Reg_MSE': 78.4971064477}
Epoch: 454 - {'Reg_MSE': 78.3677883026}
Epoch: 455 - {'Reg_MSE': 78.2391606382}
Epoch: 456 - {'Reg_MSE': 78.1112169132}
Epoch: 457 - {'Reg_MSE': 77.9839506851}
Epoch: 458 - {'Reg_MSE': 77.8573556089}
Epoch: 459 - {'Reg_MSE': 77.7314254347}
Epoch: 460 - {'Reg_MSE': 77.606154006}
Epoch: 461 - {'Reg_MSE': 77.4815352579}
Epoch: 462 - {'Reg_MSE': 77.3575632151}
Epoch: 463 - {'Reg_MSE': 77.2342319903}
Epoch: 464 - {'Reg_MSE': 77.1115357826}
Epoch: 465 - {'Reg_MSE': 76.9894688754}
Epoch: 466 - {'Reg_MSE': 76.8680256351}
Epoch: 467 - {'Reg_MSE': 76.7472005095}
Epoch: 468 - {'Reg_MSE': 76.6269880259}
Epoch: 469 - {'Reg_MSE': 76.5073827899}
Epoch: 470 - {'Reg_MSE': 76.3883794837}
Epoch: 471 - {'Reg_MSE': 76.2699728648}
Epoch: 472 - {'Reg_MSE': 76.1521577642}
Epoch: 473 - {'Reg_MSE': 76.0349290853}
Epoch: 474 - {'Reg_MSE': 75.9182818028}
Epoch: 475 - {'Reg_MSE': 75.8022109604}
Epoch: 476 - {'Reg_MSE': 75.6867116704}
Epoch: 477 - {'Reg_MSE': 75.5717791122}
Epoch: 478 - {'Reg_MSE': 75.4574085306}
Epoch: 479 - {'Reg_MSE': 75.3435952351}
Epoch: 480 - {'Reg_MSE': 75.2303345984}
Epoch: 481 - {'Reg_MSE': 75.1176220552}
Epoch: 482 - {'Reg_MSE': 75.0054531013}
Epoch: 483 - {'Reg_MSE': 74.8938232922}
Epoch: 484 - {'Reg_MSE': 74.7827282421}
Epoch: 485 - {'Reg_MSE': 74.6721636226}
Epoch: 486 - {'Reg_MSE': 74.5621251623}
Epoch: 487 - {'Reg_MSE': 74.4526086448}
Epoch: 488 - {'Reg_MSE': 74.3436099085}
Epoch: 489 - {'Reg_MSE': 74.2351248452}
Epoch: 490 - {'Reg_MSE': 74.1271493992}
Epoch: 491 - {'Reg_MSE': 74.0196795663}
Epoch: 492 - {'Reg_MSE': 73.9127113932}
Epoch: 493 - {'Reg_MSE': 73.8062409759}
Epoch: 494 - {'Reg_MSE': 73.7002644597}
Epoch: 495 - {'Reg_MSE': 73.5947780375}
Epoch: 496 - {'Reg_MSE': 73.4897779494}
Epoch: 497 - {'Reg_MSE': 73.3852604818}
Epoch: 498 - {'Reg_MSE': 73.2812219665}
Epoch: 499 - {'Reg_MSE': 73.1776587801}
Epoch: 500 - {'Reg_MSE': 73.0745673429}
Epoch: 501 - {'Reg_MSE': 72.9719441182}
Epoch: 502 - {'Reg_MSE': 72.869785612}
Epoch: 503 - {'Reg_MSE': 72.7680883714}
Epoch: 504 - {'Reg_MSE': 72.6668489848}
Epoch: 505 - {'Reg_MSE': 72.5660640806}
Epoch: 506 - {'Reg_MSE': 72.4657303267}
Epoch: 507 - {'Reg_MSE': 72.3658444299}
Epoch: 508 - {'Reg_MSE': 72.266403135}
Epoch: 509 - {'Reg_MSE': 72.1674032244}
Epoch: 510 - {'Reg_MSE': 72.0688415172}
Epoch: 511 - {'Reg_MSE': 71.9707148691}
Epoch: 512 - {'Reg_MSE': 71.873020171}
Epoch: 513 - {'Reg_MSE': 71.7757543492}
Epoch: 514 - {'Reg_MSE': 71.6789143641}
Epoch: 515 - {'Reg_MSE': 71.5824972104}
Epoch: 516 - {'Reg_MSE': 71.4864999157}
Epoch: 517 - {'Reg_MSE': 71.3909195406}
Epoch: 518 - {'Reg_MSE': 71.2957531779}
Epoch: 519 - {'Reg_MSE': 71.2009979521}
Epoch: 520 - {'Reg_MSE': 71.1066510189}
Epoch: 521 - {'Reg_MSE': 71.0127095646}
Epoch: 522 - {'Reg_MSE': 70.9191708058}
Epoch: 523 - {'Reg_MSE': 70.8260319888}
Epoch: 524 - {'Reg_MSE': 70.7332903889}
Epoch: 525 - {'Reg_MSE': 70.6409433105}
Epoch: 526 - {'Reg_MSE': 70.5489880861}
Epoch: 527 - {'Reg_MSE': 70.4574220761}
Epoch: 528 - {'Reg_MSE': 70.3662426683}
Epoch: 529 - {'Reg_MSE': 70.2754472775}
Epoch: 530 - {'Reg_MSE': 70.1850333452}
Epoch: 531 - {'Reg_MSE': 70.0949983387}
Epoch: 532 - {'Reg_MSE': 70.0053397515}
Epoch: 533 - {'Reg_MSE': 69.9160551022}
Epoch: 534 - {'Reg_MSE': 69.8271419343}
Epoch: 535 - {'Reg_MSE': 69.738597816}
Epoch: 536 - {'Reg_MSE': 69.6504203398}
Epoch: 537 - {'Reg_MSE': 69.5626071219}
Epoch: 538 - {'Reg_MSE': 69.475155802}
Epoch: 539 - {'Reg_MSE': 69.388064043}
Epoch: 540 - {'Reg_MSE': 69.3013295305}
Epoch: 541 - {'Reg_MSE': 69.2149499727}
Epoch: 542 - {'Reg_MSE': 69.1289230998}
Epoch: 543 - {'Reg_MSE': 69.0432466638}
Epoch: 544 - {'Reg_MSE': 68.9579184382}
Epoch: 545 - {'Reg_MSE': 68.8729362177}
Epoch: 546 - {'Reg_MSE': 68.7882978177}
Epoch: 547 - {'Reg_MSE': 68.7040010745}
Epoch: 548 - {'Reg_MSE': 68.6200438441}
Epoch: 549 - {'Reg_MSE': 68.5364240031}
Epoch: 550 - {'Reg_MSE': 68.4531394473}
Epoch: 551 - {'Reg_MSE': 68.3701880922}
Epoch: 552 - {'Reg_MSE': 68.2875678723}
Epoch: 553 - {'Reg_MSE': 68.205276741}
Epoch: 554 - {'Reg_MSE': 68.1233126702}
Epoch: 555 - {'Reg_MSE': 68.0416736503}
Epoch: 556 - {'Reg_MSE': 67.9603576898}
Epoch: 557 - {'Reg_MSE': 67.8793628148}
Epoch: 558 - {'Reg_MSE': 67.7986870694}
Epoch: 559 - {'Reg_MSE': 67.7183285146}
Epoch: 560 - {'Reg_MSE': 67.6382852288}
Epoch: 561 - {'Reg_MSE': 67.5585553073}
Epoch: 562 - {'Reg_MSE': 67.479136862}
Epoch: 563 - {'Reg_MSE': 67.4000280213}
Epoch: 564 - {'Reg_MSE': 67.3212269296}
Epoch: 565 - {'Reg_MSE': 67.2427317476}
Epoch: 566 - {'Reg_MSE': 67.1645406516}
Epoch: 567 - {'Reg_MSE': 67.0866518337}
Epoch: 568 - {'Reg_MSE': 67.0090635011}
Epoch: 569 - {'Reg_MSE': 66.9317738765}
Epoch: 570 - {'Reg_MSE': 66.8547811974}
Epoch: 571 - {'Reg_MSE': 66.7780837162}
Epoch: 572 - {'Reg_MSE': 66.7016796997}
Epoch: 573 - {'Reg_MSE': 66.6255674295}
Epoch: 574 - {'Reg_MSE': 66.5497452012}
Epoch: 575 - {'Reg_MSE': 66.4742113244}
Epoch: 576 - {'Reg_MSE': 66.3989641227}
Epoch: 577 - {'Reg_MSE': 66.3240019336}
Epoch: 578 - {'Reg_MSE': 66.2493231078}
Epoch: 579 - {'Reg_MSE': 66.1749260097}
Epoch: 580 - {'Reg_MSE': 66.1008090166}
Epoch: 581 - {'Reg_MSE': 66.0269705191}
Epoch: 582 - {'Reg_MSE': 65.9534089207}
Epoch: 583 - {'Reg_MSE': 65.8801226375}
Epoch: 584 - {'Reg_MSE': 65.8071100983}
Epoch: 585 - {'Reg_MSE': 65.7343697443}
Epoch: 586 - {'Reg_MSE': 65.6619000291}
Epoch: 587 - {'Reg_MSE': 65.5896994181}
Epoch: 588 - {'Reg_MSE': 65.5177663892}
Epoch: 589 - {'Reg_MSE': 65.4460994318}
Epoch: 590 - {'Reg_MSE': 65.3746970472}
Epoch: 591 - {'Reg_MSE': 65.3035577482}
Epoch: 592 - {'Reg_MSE': 65.2326800592}
Epoch: 593 - {'Reg_MSE': 65.1620625157}
Epoch: 594 - {'Reg_MSE': 65.0917036646}
Epoch: 595 - {'Reg_MSE': 65.0216020637}
Epoch: 596 - {'Reg_MSE': 64.951756282}
Epoch: 597 - {'Reg_MSE': 64.882164899}
Epoch: 598 - {'Reg_MSE': 64.8128265053}
Epoch: 599 - {'Reg_MSE': 64.7437397016}
Epoch: 600 - {'Reg_MSE': 64.6749030994}
Epoch: 601 - {'Reg_MSE': 64.6063153206}
Epoch: 602 - {'Reg_MSE': 64.537974997}
Epoch: 603 - {'Reg_MSE': 64.469880771}
Epoch: 604 - {'Reg_MSE': 64.4020312945}
Epoch: 605 - {'Reg_MSE': 64.3344252297}
Epoch: 606 - {'Reg_MSE': 64.2670612484}
Epoch: 607 - {'Reg_MSE': 64.1999380323}
Epoch: 608 - {'Reg_MSE': 64.1330542725}
Epoch: 609 - {'Reg_MSE': 64.0664086695}
Epoch: 610 - {'Reg_MSE': 63.9999999336}
Epoch: 611 - {'Reg_MSE': 63.9338267841}
Epoch: 612 - {'Reg_MSE': 63.8678879494}
Epoch: 613 - {'Reg_MSE': 63.8021821673}
Epoch: 614 - {'Reg_MSE': 63.7367081845}
Epoch: 615 - {'Reg_MSE': 63.6714647564}
Epoch: 616 - {'Reg_MSE': 63.6064506476}
Epoch: 617 - {'Reg_MSE': 63.5416646312}
Epoch: 618 - {'Reg_MSE': 63.4771054889}
Epoch: 619 - {'Reg_MSE': 63.4127720112}
Epoch: 620 - {'Reg_MSE': 63.3486629969}
Epoch: 621 - {'Reg_MSE': 63.2847772533}
Epoch: 622 - {'Reg_MSE': 63.2211135958}
Epoch: 623 - {'Reg_MSE': 63.1576708482}
Epoch: 624 - {'Reg_MSE': 63.0944478426}
Epoch: 625 - {'Reg_MSE': 63.0314434188}
Epoch: 626 - {'Reg_MSE': 62.9686564248}
Epoch: 627 - {'Reg_MSE': 62.9060857166}
Epoch: 628 - {'Reg_MSE': 62.8437301578}
Epoch: 629 - {'Reg_MSE': 62.78158862}
Epoch: 630 - {'Reg_MSE': 62.7196599823}
Epoch: 631 - {'Reg_MSE': 62.6579431315}
Epoch: 632 - {'Reg_MSE': 62.5964369619}
Epoch: 633 - {'Reg_MSE': 62.5351403753}
Epoch: 634 - {'Reg_MSE': 62.474052281}
Epoch: 635 - {'Reg_MSE': 62.4131715954}
Epoch: 636 - {'Reg_MSE': 62.3524972423}
Epoch: 637 - {'Reg_MSE': 62.2920281528}
Epoch: 638 - {'Reg_MSE': 62.2317632649}
Epoch: 639 - {'Reg_MSE': 62.1717015239}
Epoch: 640 - {'Reg_MSE': 62.1118418819}
Epoch: 641 - {'Reg_MSE': 62.052183298}
Epoch: 642 - {'Reg_MSE': 61.9927247382}
Epoch: 643 - {'Reg_MSE': 61.9334651754}
Epoch: 644 - {'Reg_MSE': 61.8744035891}
Epoch: 645 - {'Reg_MSE': 61.8155389655}
Epoch: 646 - {'Reg_MSE': 61.7568702976}
Epoch: 647 - {'Reg_MSE': 61.6983965847}
Epoch: 648 - {'Reg_MSE': 61.6401168329}
Epoch: 649 - {'Reg_MSE': 61.5820300546}
Epoch: 650 - {'Reg_MSE': 61.5241352685}
Epoch: 651 - {'Reg_MSE': 61.4664315}
Epoch: 652 - {'Reg_MSE': 61.4089177803}
Epoch: 653 - {'Reg_MSE': 61.3515931473}
Epoch: 654 - {'Reg_MSE': 61.2944566448}
Epoch: 655 - {'Reg_MSE': 61.2375073229}
Epoch: 656 - {'Reg_MSE': 61.1807442375}
Epoch: 657 - {'Reg_MSE': 61.1241664509}
Epoch: 658 - {'Reg_MSE': 61.0677730311}
Epoch: 659 - {'Reg_MSE': 61.0115630521}
Epoch: 660 - {'Reg_MSE': 60.9555355939}
Epoch: 661 - {'Reg_MSE': 60.8996897421}
Epoch: 662 - {'Reg_MSE': 60.8440245882}
Epoch: 663 - {'Reg_MSE': 60.7885392294}
Epoch: 664 - {'Reg_MSE': 60.7332327687}
Epoch: 665 - {'Reg_MSE': 60.6781043146}
Epoch: 666 - {'Reg_MSE': 60.6231529812}
Epoch: 667 - {'Reg_MSE': 60.5683778882}
Epoch: 668 - {'Reg_MSE': 60.5137781607}
Epoch: 669 - {'Reg_MSE': 60.4593529294}
Epoch: 670 - {'Reg_MSE': 60.4051013303}
Epoch: 671 - {'Reg_MSE': 60.3510225047}
Epoch: 672 - {'Reg_MSE': 60.2971155995}
Epoch: 673 - {'Reg_MSE': 60.2433797666}
Epoch: 674 - {'Reg_MSE': 60.1898141632}
Epoch: 675 - {'Reg_MSE': 60.1364179518}
Epoch: 676 - {'Reg_MSE': 60.0831902999}
Epoch: 677 - {'Reg_MSE': 60.0301303803}
Epoch: 678 - {'Reg_MSE': 59.9772373708}
Epoch: 679 - {'Reg_MSE': 59.924510454}
Epoch: 680 - {'Reg_MSE': 59.871948818}
Epoch: 681 - {'Reg_MSE': 59.8195516553}
Epoch: 682 - {'Reg_MSE': 59.7673181637}
Epoch: 683 - {'Reg_MSE': 59.7152475458}
Epoch: 684 - {'Reg_MSE': 59.6633390089}
Epoch: 685 - {'Reg_MSE': 59.6115917653}
Epoch: 686 - {'Reg_MSE': 59.5600050319}
Epoch: 687 - {'Reg_MSE': 59.5085780305}
Epoch: 688 - {'Reg_MSE': 59.4573099874}
Epoch: 689 - {'Reg_MSE': 59.4062001337}
Epoch: 690 - {'Reg_MSE': 59.3552477051}
Epoch: 691 - {'Reg_MSE': 59.3044519418}
Epoch: 692 - {'Reg_MSE': 59.2538120888}
Epoch: 693 - {'Reg_MSE': 59.2033273953}
Epoch: 694 - {'Reg_MSE': 59.1529971152}
Epoch: 695 - {'Reg_MSE': 59.1028205068}
Epoch: 696 - {'Reg_MSE': 59.0527968328}
Epoch: 697 - {'Reg_MSE': 59.0029253604}
Epoch: 698 - {'Reg_MSE': 58.9532053609}
Epoch: 699 - {'Reg_MSE': 58.9036361101}
Epoch: 700 - {'Reg_MSE': 58.8542168883}
Epoch: 701 - {'Reg_MSE': 58.8049469795}
Epoch: 702 - {'Reg_MSE': 58.7558256726}
Epoch: 703 - {'Reg_MSE': 58.7068522601}
Epoch: 704 - {'Reg_MSE': 58.6580260391}
Epoch: 705 - {'Reg_MSE': 58.6093463106}
Epoch: 706 - {'Reg_MSE': 58.5608123798}
Epoch: 707 - {'Reg_MSE': 58.5124235559}
Epoch: 708 - {'Reg_MSE': 58.4641791523}
Epoch: 709 - {'Reg_MSE': 58.4160784863}
Epoch: 710 - {'Reg_MSE': 58.3681208792}
Epoch: 711 - {'Reg_MSE': 58.3203056563}
Epoch: 712 - {'Reg_MSE': 58.2726321468}
Epoch: 713 - {'Reg_MSE': 58.2250996838}
Epoch: 714 - {'Reg_MSE': 58.1777076044}
Epoch: 715 - {'Reg_MSE': 58.1304552493}
Epoch: 716 - {'Reg_MSE': 58.0833419633}
Epoch: 717 - {'Reg_MSE': 58.0363670948}
Epoch: 718 - {'Reg_MSE': 57.9895299961}
Epoch: 719 - {'Reg_MSE': 57.9428300232}
Epoch: 720 - {'Reg_MSE': 57.8962665358}
Epoch: 721 - {'Reg_MSE': 57.8498388972}
Epoch: 722 - {'Reg_MSE': 57.8035464746}
Epoch: 723 - {'Reg_MSE': 57.7573886387}
Epoch: 724 - {'Reg_MSE': 57.7113647638}
Epoch: 725 - {'Reg_MSE': 57.6654742277}
Epoch: 726 - {'Reg_MSE': 57.6197164121}
Epoch: 727 - {'Reg_MSE': 57.5740907019}
Epoch: 728 - {'Reg_MSE': 57.5285964857}
Epoch: 729 - {'Reg_MSE': 57.4832331554}
Epoch: 730 - {'Reg_MSE': 57.4380001067}
Epoch: 731 - {'Reg_MSE': 57.3928967384}
Epoch: 732 - {'Reg_MSE': 57.347922453}
Epoch: 733 - {'Reg_MSE': 57.3030766562}
Epoch: 734 - {'Reg_MSE': 57.2583587572}
Epoch: 735 - {'Reg_MSE': 57.2137681685}
Epoch: 736 - {'Reg_MSE': 57.1693043059}
Epoch: 737 - {'Reg_MSE': 57.1249665885}
Epoch: 738 - {'Reg_MSE': 57.0807544388}
Epoch: 739 - {'Reg_MSE': 57.0366672826}
Epoch: 740 - {'Reg_MSE': 56.9927045487}
Epoch: 741 - {'Reg_MSE': 56.9488656692}
Epoch: 742 - {'Reg_MSE': 56.9051500797}
Epoch: 743 - {'Reg_MSE': 56.8615572185}
Epoch: 744 - {'Reg_MSE': 56.8180865274}
Epoch: 745 - {'Reg_MSE': 56.7747374512}
Epoch: 746 - {'Reg_MSE': 56.7315094379}
Epoch: 747 - {'Reg_MSE': 56.6884019385}
Epoch: 748 - {'Reg_MSE': 56.6454144071}
Epoch: 749 - {'Reg_MSE': 56.602546301}
Epoch: 750 - {'Reg_MSE': 56.5597970802}
Epoch: 751 - {'Reg_MSE': 56.517166208}
Epoch: 752 - {'Reg_MSE': 56.4746531507}
Epoch: 753 - {'Reg_MSE': 56.4322573773}
Epoch: 754 - {'Reg_MSE': 56.3899783602}
Epoch: 755 - {'Reg_MSE': 56.3478155742}
Epoch: 756 - {'Reg_MSE': 56.3057684975}
Epoch: 757 - {'Reg_MSE': 56.2638366109}
Epoch: 758 - {'Reg_MSE': 56.2220193982}
Epoch: 759 - {'Reg_MSE': 56.180316346}
Epoch: 760 - {'Reg_MSE': 56.1387269439}
Epoch: 761 - {'Reg_MSE': 56.097250684}
Epoch: 762 - {'Reg_MSE': 56.0558870615}
Epoch: 763 - {'Reg_MSE': 56.0146355743}
Epoch: 764 - {'Reg_MSE': 55.9734957231}
Epoch: 765 - {'Reg_MSE': 55.9324670112}
Epoch: 766 - {'Reg_MSE': 55.8915489449}
Epoch: 767 - {'Reg_MSE': 55.8507410329}
Epoch: 768 - {'Reg_MSE': 55.8100427868}
Epoch: 769 - {'Reg_MSE': 55.769453721}
Epoch: 770 - {'Reg_MSE': 55.7289733522}
Epoch: 771 - {'Reg_MSE': 55.6886012001}
Epoch: 772 - {'Reg_MSE': 55.6483367868}
Epoch: 773 - {'Reg_MSE': 55.6081796372}
Epoch: 774 - {'Reg_MSE': 55.5681292786}
Epoch: 775 - {'Reg_MSE': 55.5281852411}
Epoch: 776 - {'Reg_MSE': 55.4883470571}
Epoch: 777 - {'Reg_MSE': 55.4486142618}
Epoch: 778 - {'Reg_MSE': 55.4089863928}
Epoch: 779 - {'Reg_MSE': 55.3694629902}
Epoch: 780 - {'Reg_MSE': 55.3300435967}
Epoch: 781 - {'Reg_MSE': 55.2907277575}
Epoch: 782 - {'Reg_MSE': 55.2515150201}
Epoch: 783 - {'Reg_MSE': 55.2124049346}
Epoch: 784 - {'Reg_MSE': 55.1733970535}
Epoch: 785 - {'Reg_MSE': 55.1344909316}
Epoch: 786 - {'Reg_MSE': 55.0956861265}
Epoch: 787 - {'Reg_MSE': 55.0569821976}
Epoch: 788 - {'Reg_MSE': 55.0183787073}
Epoch: 789 - {'Reg_MSE': 54.9798752199}
Epoch: 790 - {'Reg_MSE': 54.9414713022}
Epoch: 791 - {'Reg_MSE': 54.9031665234}
Epoch: 792 - {'Reg_MSE': 54.864960455}
Epoch: 793 - {'Reg_MSE': 54.8268526707}
Epoch: 794 - {'Reg_MSE': 54.7888427467}
Epoch: 795 - {'Reg_MSE': 54.7509302611}
Epoch: 796 - {'Reg_MSE': 54.7131147948}
Epoch: 797 - {'Reg_MSE': 54.6753959305}
Epoch: 798 - {'Reg_MSE': 54.6377732533}
Epoch: 799 - {'Reg_MSE': 54.6002463506}
Epoch: 800 - {'Reg_MSE': 54.5628148118}
Epoch: 801 - {'Reg_MSE': 54.5254782287}
Epoch: 802 - {'Reg_MSE': 54.4882361951}
Epoch: 803 - {'Reg_MSE': 54.4510883073}
Epoch: 804 - {'Reg_MSE': 54.4140341633}
Epoch: 805 - {'Reg_MSE': 54.3770733636}
Epoch: 806 - {'Reg_MSE': 54.3402055106}
Epoch: 807 - {'Reg_MSE': 54.303430209}
Epoch: 808 - {'Reg_MSE': 54.2667470654}
Epoch: 809 - {'Reg_MSE': 54.2301556887}
Epoch: 810 - {'Reg_MSE': 54.1936556897}
Epoch: 811 - {'Reg_MSE': 54.1572466815}
Epoch: 812 - {'Reg_MSE': 54.1209282789}
Epoch: 813 - {'Reg_MSE': 54.0847000991}
Epoch: 814 - {'Reg_MSE': 54.0485617611}
Epoch: 815 - {'Reg_MSE': 54.0125128859}
Epoch: 816 - {'Reg_MSE': 53.9765530967}
Epoch: 817 - {'Reg_MSE': 53.9406820186}
Epoch: 818 - {'Reg_MSE': 53.9048992785}
Epoch: 819 - {'Reg_MSE': 53.8692045055}
Epoch: 820 - {'Reg_MSE': 53.8335973306}
Epoch: 821 - {'Reg_MSE': 53.7980773867}
Epoch: 822 - {'Reg_MSE': 53.7626443087}
Epoch: 823 - {'Reg_MSE': 53.7272977332}
Epoch: 824 - {'Reg_MSE': 53.6920372991}
Epoch: 825 - {'Reg_MSE': 53.6568626469}
Epoch: 826 - {'Reg_MSE': 53.6217734189}
Epoch: 827 - {'Reg_MSE': 53.5867692597}
Epoch: 828 - {'Reg_MSE': 53.5518498153}
Epoch: 829 - {'Reg_MSE': 53.5170147338}
Epoch: 830 - {'Reg_MSE': 53.4822636652}
Epoch: 831 - {'Reg_MSE': 53.4475962611}
Epoch: 832 - {'Reg_MSE': 53.413012175}
Epoch: 833 - {'Reg_MSE': 53.3785110624}
Epoch: 834 - {'Reg_MSE': 53.3440925804}
Epoch: 835 - {'Reg_MSE': 53.3097563879}
Epoch: 836 - {'Reg_MSE': 53.2755021456}
Epoch: 837 - {'Reg_MSE': 53.241329516}
Epoch: 838 - {'Reg_MSE': 53.2072381634}
Epoch: 839 - {'Reg_MSE': 53.1732277537}
Epoch: 840 - {'Reg_MSE': 53.1392979546}
Epoch: 841 - {'Reg_MSE': 53.1054484356}
Epoch: 842 - {'Reg_MSE': 53.0716788678}
Epoch: 843 - {'Reg_MSE': 53.0379889241}
Epoch: 844 - {'Reg_MSE': 53.0043782791}
Epoch: 845 - {'Reg_MSE': 52.9708466089}
Epoch: 846 - {'Reg_MSE': 52.9373935915}
Epoch: 847 - {'Reg_MSE': 52.9040189064}
Epoch: 848 - {'Reg_MSE': 52.8707222349}
Epoch: 849 - {'Reg_MSE': 52.8375032599}
Epoch: 850 - {'Reg_MSE': 52.8043616658}
Epoch: 851 - {'Reg_MSE': 52.7712971388}
Epoch: 852 - {'Reg_MSE': 52.7383093667}
Epoch: 853 - {'Reg_MSE': 52.7053980388}
Epoch: 854 - {'Reg_MSE': 52.672562846}
Epoch: 855 - {'Reg_MSE': 52.639803481}
Epoch: 856 - {'Reg_MSE': 52.6071196377}
Epoch: 857 - {'Reg_MSE': 52.574511012}
Epoch: 858 - {'Reg_MSE': 52.5419773011}
Epoch: 859 - {'Reg_MSE': 52.5095182037}
Epoch: 860 - {'Reg_MSE': 52.4771334202}
Epoch: 861 - {'Reg_MSE': 52.4448226526}
Epoch: 862 - {'Reg_MSE': 52.4125856041}
Epoch: 863 - {'Reg_MSE': 52.3804219797}
Epoch: 864 - {'Reg_MSE': 52.3483314859}
Epoch: 865 - {'Reg_MSE': 52.3163138306}
Epoch: 866 - {'Reg_MSE': 52.2843687231}
Epoch: 867 - {'Reg_MSE': 52.2524958744}
Epoch: 868 - {'Reg_MSE': 52.2206949968}
Epoch: 869 - {'Reg_MSE': 52.1889658041}
Epoch: 870 - {'Reg_MSE': 52.1573080117}
Epoch: 871 - {'Reg_MSE': 52.1257213363}
Epoch: 872 - {'Reg_MSE': 52.094205496}
Epoch: 873 - {'Reg_MSE': 52.0627602105}
Epoch: 874 - {'Reg_MSE': 52.0313852007}
Epoch: 875 - {'Reg_MSE': 52.0000801891}
Epoch: 876 - {'Reg_MSE': 51.9688448995}
Epoch: 877 - {'Reg_MSE': 51.9376790571}
Epoch: 878 - {'Reg_MSE': 51.9065823887}
Epoch: 879 - {'Reg_MSE': 51.8755546221}
Epoch: 880 - {'Reg_MSE': 51.8445954868}
Epoch: 881 - {'Reg_MSE': 51.8137047136}
Epoch: 882 - {'Reg_MSE': 51.7828820344}
Epoch: 883 - {'Reg_MSE': 51.7521271828}
Epoch: 884 - {'Reg_MSE': 51.7214398936}
Epoch: 885 - {'Reg_MSE': 51.6908199028}
Epoch: 886 - {'Reg_MSE': 51.660266948}
Epoch: 887 - {'Reg_MSE': 51.6297807679}
Epoch: 888 - {'Reg_MSE': 51.5993611027}
Epoch: 889 - {'Reg_MSE': 51.5690076936}
Epoch: 890 - {'Reg_MSE': 51.5387202834}
Epoch: 891 - {'Reg_MSE': 51.508498616}
Epoch: 892 - {'Reg_MSE': 51.4783424368}
Epoch: 893 - {'Reg_MSE': 51.4482514922}
Epoch: 894 - {'Reg_MSE': 51.41822553}
Epoch: 895 - {'Reg_MSE': 51.3882642994}
Epoch: 896 - {'Reg_MSE': 51.3583675506}
Epoch: 897 - {'Reg_MSE': 51.3285350351}
Epoch: 898 - {'Reg_MSE': 51.2987665058}
Epoch: 899 - {'Reg_MSE': 51.2690617167}
Epoch: 900 - {'Reg_MSE': 51.2394204231}
Epoch: 901 - {'Reg_MSE': 51.2098423814}
Epoch: 902 - {'Reg_MSE': 51.1803273493}
Epoch: 903 - {'Reg_MSE': 51.1508750857}
Epoch: 904 - {'Reg_MSE': 51.1214853507}
Epoch: 905 - {'Reg_MSE': 51.0921579055}
Epoch: 906 - {'Reg_MSE': 51.0628925127}
Epoch: 907 - {'Reg_MSE': 51.0336889357}
Epoch: 908 - {'Reg_MSE': 51.0045469396}
Epoch: 909 - {'Reg_MSE': 50.9754662901}
Epoch: 910 - {'Reg_MSE': 50.9464467545}
Epoch: 911 - {'Reg_MSE': 50.917488101}
Epoch: 912 - {'Reg_MSE': 50.888590099}
Epoch: 913 - {'Reg_MSE': 50.8597525191}
Epoch: 914 - {'Reg_MSE': 50.830975133}
Epoch: 915 - {'Reg_MSE': 50.8022577135}
Epoch: 916 - {'Reg_MSE': 50.7736000346}
Epoch: 917 - {'Reg_MSE': 50.7450018712}
Epoch: 918 - {'Reg_MSE': 50.7164629996}
Epoch: 919 - {'Reg_MSE': 50.6879831969}
Epoch: 920 - {'Reg_MSE': 50.6595622415}
Epoch: 921 - {'Reg_MSE': 50.6311999129}
Epoch: 922 - {'Reg_MSE': 50.6028959916}
Epoch: 923 - {'Reg_MSE': 50.5746502591}
Epoch: 924 - {'Reg_MSE': 50.5464624982}
Epoch: 925 - {'Reg_MSE': 50.5183324925}
Epoch: 926 - {'Reg_MSE': 50.4902600268}
Epoch: 927 - {'Reg_MSE': 50.462244887}
Epoch: 928 - {'Reg_MSE': 50.4342868598}
Epoch: 929 - {'Reg_MSE': 50.4063857334}
Epoch: 930 - {'Reg_MSE': 50.3785412965}
Epoch: 931 - {'Reg_MSE': 50.3507533392}
Epoch: 932 - {'Reg_MSE': 50.3230216525}
Epoch: 933 - {'Reg_MSE': 50.2953460284}
Epoch: 934 - {'Reg_MSE': 50.2677262598}
Epoch: 935 - {'Reg_MSE': 50.240162141}
Epoch: 936 - {'Reg_MSE': 50.2126534668}
Epoch: 937 - {'Reg_MSE': 50.1852000334}
Epoch: 938 - {'Reg_MSE': 50.1578016377}
Epoch: 939 - {'Reg_MSE': 50.1304580777}
Epoch: 940 - {'Reg_MSE': 50.1031691525}
Epoch: 941 - {'Reg_MSE': 50.0759346619}
Epoch: 942 - {'Reg_MSE': 50.048754407}
Epoch: 943 - {'Reg_MSE': 50.0216281895}
Epoch: 944 - {'Reg_MSE': 49.9945558123}
Epoch: 945 - {'Reg_MSE': 49.9675370792}
Epoch: 946 - {'Reg_MSE': 49.9405717949}
Epoch: 947 - {'Reg_MSE': 49.9136597652}
Epoch: 948 - {'Reg_MSE': 49.8868007966}
Epoch: 949 - {'Reg_MSE': 49.8599946966}
Epoch: 950 - {'Reg_MSE': 49.8332412738}
Epoch: 951 - {'Reg_MSE': 49.8065403374}
Epoch: 952 - {'Reg_MSE': 49.7798916979}
Epoch: 953 - {'Reg_MSE': 49.7532951664}
Epoch: 954 - {'Reg_MSE': 49.7267505551}
Epoch: 955 - {'Reg_MSE': 49.7002576769}
Epoch: 956 - {'Reg_MSE': 49.6738163458}
Epoch: 957 - {'Reg_MSE': 49.6474263765}
Epoch: 958 - {'Reg_MSE': 49.6210875848}
Epoch: 959 - {'Reg_MSE': 49.5947997872}
Epoch: 960 - {'Reg_MSE': 49.5685628012}
Epoch: 961 - {'Reg_MSE': 49.5423764451}
Epoch: 962 - {'Reg_MSE': 49.516240538}
Epoch: 963 - {'Reg_MSE': 49.4901549}
Epoch: 964 - {'Reg_MSE': 49.464119352}
Epoch: 965 - {'Reg_MSE': 49.4381337157}
Epoch: 966 - {'Reg_MSE': 49.4121978138}
Epoch: 967 - {'Reg_MSE': 49.3863114696}
Epoch: 968 - {'Reg_MSE': 49.3604745075}
Epoch: 969 - {'Reg_MSE': 49.3346867525}
Epoch: 970 - {'Reg_MSE': 49.3089480307}
Epoch: 971 - {'Reg_MSE': 49.2832581686}
Epoch: 972 - {'Reg_MSE': 49.257616994}
Epoch: 973 - {'Reg_MSE': 49.2320243353}
Epoch: 974 - {'Reg_MSE': 49.2064800216}
Epoch: 975 - {'Reg_MSE': 49.1809838829}
Epoch: 976 - {'Reg_MSE': 49.1555357501}
Epoch: 977 - {'Reg_MSE': 49.1301354548}
Epoch: 978 - {'Reg_MSE': 49.1047828293}
Epoch: 979 - {'Reg_MSE': 49.079477707}
Epoch: 980 - {'Reg_MSE': 49.0542199217}
Epoch: 981 - {'Reg_MSE': 49.0290093082}
Epoch: 982 - {'Reg_MSE': 49.0038457021}
Epoch: 983 - {'Reg_MSE': 48.9787289396}
Epoch: 984 - {'Reg_MSE': 48.9536588579}
Epoch: 985 - {'Reg_MSE': 48.9286352948}
Epoch: 986 - {'Reg_MSE': 48.9036580889}
Epoch: 987 - {'Reg_MSE': 48.8787270795}
Epoch: 988 - {'Reg_MSE': 48.8538421067}
Epoch: 989 - {'Reg_MSE': 48.8290030115}
Epoch: 990 - {'Reg_MSE': 48.8042096353}
Epoch: 991 - {'Reg_MSE': 48.7794618206}
Epoch: 992 - {'Reg_MSE': 48.7547594104}
Epoch: 993 - {'Reg_MSE': 48.7301022484}
Epoch: 994 - {'Reg_MSE': 48.7054901793}
Epoch: 995 - {'Reg_MSE': 48.6809230483}
Epoch: 996 - {'Reg_MSE': 48.6564007014}
Epoch: 997 - {'Reg_MSE': 48.6319229852}
Epoch: 998 - {'Reg_MSE': 48.6074897471}
Epoch: 999 - {'Reg_MSE': 48.5831008353}
Epoch: 1000 - {'Reg_MSE': 48.5587560986}
Mean Square Error = 48.5587
Implement a Homegrown Logistic Regression model. Extend the loss function from CXE to CXE + MSE, i.e., make it a complex multitask loss function where the resulting model predicts the class and bounding box coordinates at the same time
import warnings
warnings.filterwarnings('ignore')
class LogisticRegressionHomegrown(object):
def __init__(self):
"""
Constructor for the homgrown Logistic Regression
Args:
None
Return:
None
"""
self.coef_r = None # weight vector
self.intercept_r = None # bias term
self.coef_c = None # weight vector
self.intercept_c = None # bias term
self._thetaReg = None # augmented weight vector, i.e., bias + weights
# this allows to treat all decision variables homogeneously
self._thetaClass = None
self.history = {"CXE+MSE_train": [],
"Class_train_acc": [],
"Class_train_CXE":[],
"Reg_train_MSE":[],
"val_CXE+MSE":[],
"Class_val_CXE":[],
"Class_val_acc": [],
"Reg_val_MSE":[]}
# def normalize(X):
# # X --> Input.
# # m-> number of training examples
# # n-> number of features
# m, n = X.shape
# # Normalizing all the n features of X.
# for i in range(n):
# X = (X - X.mean(axis=0))/X.std(axis=0)
# # scale data
# np.random.seed(42)
# if np.max(X) > 4.:
# X = X.astype(np.float32) / 255.
# return X
# gradient for regressor
def _gradReg(self, X, y):
# number of training examples
n = X.shape[0]
# get scores for each class and example
# 2D matrix
scores = self._predict_raw(X,val=1)
gradient = np.dot(X.T, scores) / n
return gradient
# gradient for classifier
def _gradClass(self, X, y):
"""
Calculates the gradient of the Logistic Regression
objective function
Args:
X(ndarray): train objects
y(ndarray): answers for train objects
Return:
grad(ndarray): gradient
"""
# number of training examples
n = X.shape[0]
# get scores for each class and example
# 2D matrix
scores = self._predict_raw(X,val=2)
# transform scores to probabilities
# softmax
probs = 1.0/(1 + np.exp(-scores))
#probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
# error
probs[range(n),y] -= 1
# (1/m)*np.dot(X.T, (scores - y))
# gradient
gradient = np.dot(X.T, probs) / n
return gradient
def _gd(self, X_r, y_r,X_c,y_c, max_iter, alpha, X_val_r, y_val_r, X_val_c, y_val_c):
"""
Runs Full GD and logs error, weigths, gradient at every step
Args:
X(ndarray): train objects
y(ndarray): answers for train objects
max_iter(int): number of weight updates
alpha(floar): step size in direction of gradient
Return:
None
"""
for i in range(max_iter):
metrics = self.score(X_r, y_r, X_c, y_c)
print("Epoch: ",i+1,"- ", metrics)
self.history["CXE+MSE_train"].append(metrics["CXE+MSE"])
self.history["Class_train_acc"].append(metrics["Class_acc"])
self.history["Class_train_CXE"].append(metrics["Class_CXE"])
self.history["Reg_train_MSE"].append(metrics["Reg_MSE"])
if X_val_r is not None and X_val_c is not None:
metrics_val = self.score(X_val_r, y_val_r,X_val_c, y_val_c)
self.history["val_CXE+MSE"].append(metrics_val["CXE+MSE"])
self.history["Class_val_CXE"].append(metrics_val["Class_CXE"])
self.history["Class_val_acc"].append(metrics_val["Class_acc"])
self.history["Reg_val_MSE"].append(metrics_val["Reg_MSE"])
# calculate gradient for regressor
grad_reg = self._gradReg(X_r, y_r)
# calculate gradient for classifier
grad_class = self._gradClass(X_c, y_c)
# do gradient step
self._thetaReg -= alpha * grad_reg
# do gradient step
self._thetaClass -= alpha * grad_class
def fit(self, X_r,y_r,X_c,y_c, max_iter=1000, alpha=0.05, val_data_r=None, val_data_c=None):
"""
Public API to fit Logistic regression model
Args:
X(ndarray): train objects
y(ndarray): answers for train objects
max_iter(int): number of weight updates
alpha(floar): step size in direction of gradient
Return:
None
"""
# X_r = normalize(X_r)
# X_c = normalize(X_c)
# Augment the data with the bias term.
# So we can treat the the input variables and the bias term homogeneously
# from a vectorization perspective
X_r = np.c_[np.ones(X_r.shape[0]), X_r]
if val_data_r is not None:
X_val_r, y_val_r = val_data_r
X_val_r = np.c_[np.ones(X_val_r.shape[0]), X_val_r]
else:
X_val_r = None
y_val_r = None
# initialize if the first step
if self._thetaReg is None:
self._thetaReg = np.random.rand(X_r.shape[1], 4)
#classification
X_c = np.c_[np.ones(X_c.shape[0]), X_c]
if val_data_c is not None:
X_val_c, y_val_c = val_data_c
X_val_c = np.c_[np.ones(X_val_c.shape[0]), X_val_c]
else:
X_val_c = None
y_val_c = None
# initialize if the first step
if self._thetaClass is None:
self._thetaClass = np.random.rand(X_c.shape[1], len(np.unique(y_c)))
# do full gradient descent
self._gd(X_r, y_r,X_c,y_c, max_iter, alpha, X_val_r, y_val_r, X_val_c, y_val_c)
# get final weigths and bias
self.intercept_r = self._thetaReg[0]
self.coef_r = self._thetaReg[1:]
# get final weigths and bias
self.intercept_c = self._thetaClass[0]
self.coef_c = self._thetaClass[1:]
def score(self, X_r, y_r, X_c, y_c):
# number of training samples
n1 = X_r.shape[0]
n2 = X_c.shape[0]
# get scores
scores_r = self._predict_raw(X_r,val=1)
scores_c = self._predict_raw(X_c,val=2)
pred_r=scores_r
# for i in range(0,len(scores_c)):
# pred_c[i]=scores_c[np.argmax(scores_c[i])]
# print("scores_c",scores_c)
exp=2.73
probs = 1.0/(1 + exp**(-scores_c))
# print("--------probs-------",probs)
# ind=np.argmax(scores_c,axis=1)
pred_c=[]
for i,ind in enumerate(np.argmax(scores_c,axis=1)):
pred_c.append(int(probs[i][ind]))
# probs[range(n2),ind] -= 1
# print("pred_r",pred_r,"\npred_c",pred_c,"\ny_c",y_c)
#pred_c = [1 if i > 0.5 else 0 for i in probs.all()]
# accuracy
acc = accuracy_score(y_c, np.array(pred_c))
pred_c=np.array(pred_c)
# trasnform scores to probabilities
exp_scores = np.exp(-scores_c)
probs1 = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
# print('probs1',probs1,"\nexp_scores",exp_scores)
# logloss per each example
corect_logprobs = (probs[range(n2),y_c])
# print("\ncorrect log prob",corect_logprobs)
# print(corect_logprobs)
# total mean logloss
data_loss = np.sum(corect_logprobs) / n2
loss=0
for each in corect_logprobs:
if (each!=1 and each!=0):
loss +=y_c * (-np.log (each)) + (1-y_c) * (-np.log (1-each))
# loss = y_c * np.log (corect_logprobs) + (1-y_c) * np.log (1-corect_logprobs) if ( corect_logprobs != 1 and corect_logprobs != 0 ) else 0
#loss = -np.mean(y_c*(np.log(corect_logprobs)) - (1-y_c)*np.log(1-corect_logprobs))
# final metrics
metrics = {"Class_acc": acc,
"CXE+MSE": np.round(np.mean(loss),decimals=10)+np.round(mean_squared_error(y_r, pred_r),decimals=10),
"Reg_MSE": np.round(mean_squared_error(y_r, pred_r),decimals=10),
"Class_CXE":np.round(np.mean(loss),decimals=10)}
# print(metrics)
return metrics
def _predict_raw(self, X, val):
"""
Computes scores for each class and each object in X
Args:
X(ndarray): objects
Return:
scores(ndarray): scores for each class and object
"""
if val == 1:
# check whether X has appended bias feature or not
if X.shape[1] == len(self._thetaReg):
scores = np.dot(X, self._thetaReg)
else:
scores = np.dot(X, self.coef_r) + self.intercept_r
else:
# check whether X has appended bias feature or not
if X.shape[1] == len(self._thetaClass):
scores = np.dot(X, self._thetaClass)
else:
scores = np.dot(X, self.coef_c) + self.intercept_c
return scores
def predict(self, X):
"""
Predicts class for each object in X
Args:
X(ndarray): objects
Return:
pred(ndarray): class for each object
"""
# get scores for each class
scores = self._predict_raw(X,val=2)
# choose class with maximum score
pred = np.argmax(scores, axis=1)
return pred
model_lr_homegrown = LogisticRegressionHomegrown()
#model_lr_homegrown.fit(X_train_r, y_train_r, max_iter=100, alpha=0.1,val_data=[X_valid_r,y_valid_r])
model_lr_homegrown.fit(X_train_r, y_train_r,X_train_c, y_train_c, max_iter=1000, alpha=0.00005,val_data_r=[X_valid_r,y_valid_r],val_data_c=[X_valid_c,y_valid_c])
Epoch: 1 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 1180968.4623501315, 'Reg_MSE': 493148.9830658014, 'Class_CXE': 687819.47928433}
Epoch: 2 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 1148203.6913394053, 'Reg_MSE': 460397.2814392102, 'Class_CXE': 687806.4099001951}
Epoch: 3 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 1117643.092888654, 'Reg_MSE': 429849.7349677899, 'Class_CXE': 687793.3579208641}
Epoch: 4 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 1089137.5610523035, 'Reg_MSE': 401357.2733482536, 'Class_CXE': 687780.2877040498}
Epoch: 5 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 1062548.1602737224, 'Reg_MSE': 374780.9272708056, 'Class_CXE': 687767.2330029169}
Epoch: 6 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 1038119.6905391545, 'Reg_MSE': 349991.1434686164, 'Class_CXE': 688128.547070538}
Epoch: 7 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 1014982.5956713611, 'Reg_MSE': 326867.1462275835, 'Class_CXE': 688115.4494437776}
Epoch: 8 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 993398.7342473181, 'Reg_MSE': 305296.342204617, 'Class_CXE': 688102.3920427011}
Epoch: 9 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 973263.0996991112, 'Reg_MSE': 285173.7656164684, 'Class_CXE': 688089.3340826428}
Epoch: 10 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 954477.8350479465, 'Reg_MSE': 266401.56106048, 'Class_CXE': 688076.2739874666}
Epoch: 11 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 936951.7101659675, 'Reg_MSE': 248888.5014143865, 'Class_CXE': 688063.208751581}
Epoch: 12 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 920974.0060685603, 'Reg_MSE': 232549.5384355186, 'Class_CXE': 688424.4676330417}
Epoch: 13 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 905716.7737226184, 'Reg_MSE': 217305.3838411835, 'Class_CXE': 688411.3898814349}
Epoch: 14 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 891480.439492421, 'Reg_MSE': 203082.1188024715, 'Class_CXE': 688398.3206899494}
Epoch: 15 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 878196.0835276194, 'Reg_MSE': 189810.8299240448, 'Class_CXE': 688385.2536035746}
Epoch: 16 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 865799.4577882767, 'Reg_MSE': 177427.2699131956, 'Class_CXE': 688372.1878750811}
Epoch: 17 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 854230.6552585905, 'Reg_MSE': 165871.5412633672, 'Class_CXE': 688359.1139952233}
Epoch: 18 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 843433.8501039273, 'Reg_MSE': 155087.8013909443, 'Class_CXE': 688346.048712983}
Epoch: 19 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 833356.9225507298, 'Reg_MSE': 145023.9877700321, 'Class_CXE': 688332.9347806977}
Epoch: 20 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 823951.4331957735, 'Reg_MSE': 135631.5617086672, 'Class_CXE': 688319.8714871063}
Epoch: 21 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 815172.0743377213, 'Reg_MSE': 126865.2695019425, 'Class_CXE': 688306.8048357788}
Epoch: 22 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 806976.6294053831, 'Reg_MSE': 118682.919783295, 'Class_CXE': 688293.7096220881}
Epoch: 23 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 799325.818677856, 'Reg_MSE': 111045.17597519, 'Class_CXE': 688280.642702666}
Epoch: 24 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 792182.7442842717, 'Reg_MSE': 103915.3628149608, 'Class_CXE': 688267.381469311}
Epoch: 25 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 785513.6066165797, 'Reg_MSE': 97259.2860010556, 'Class_CXE': 688254.3206155241}
Epoch: 26 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 779286.316754675, 'Reg_MSE': 91045.0640697062, 'Class_CXE': 688241.2526849688}
Epoch: 27 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 773471.1537429814, 'Reg_MSE': 85242.9716724163, 'Class_CXE': 688228.1820705651}
Epoch: 28 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 768040.4057476742, 'Reg_MSE': 79825.2934809345, 'Class_CXE': 688215.1122667397}
Epoch: 29 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 762968.2300685382, 'Reg_MSE': 74766.1879988546, 'Class_CXE': 688202.0420696837}
Epoch: 30 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 758230.5283141492, 'Reg_MSE': 70041.5606078706, 'Class_CXE': 688188.9677062787}
Epoch: 31 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 753804.8437255037, 'Reg_MSE': 65628.9452223132, 'Class_CXE': 688175.8985031905}
Epoch: 32 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 750044.5826278467, 'Reg_MSE': 61507.3939680788, 'Class_CXE': 688537.188659768}
Epoch: 33 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 746181.4723418057, 'Reg_MSE': 57657.3743416708, 'Class_CXE': 688524.0980001349}
Epoch: 34 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 742571.7056119966, 'Reg_MSE': 54060.6733420029, 'Class_CXE': 688511.0322699937}
Epoch: 35 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 739198.2767367075, 'Reg_MSE': 50700.3081020203, 'Class_CXE': 688497.9686346872}
Epoch: 36 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 736045.307604269, 'Reg_MSE': 47560.4425792882, 'Class_CXE': 688484.8650249809}
Epoch: 37 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 733098.0597354244, 'Reg_MSE': 44626.3098945964, 'Class_CXE': 688471.7498408279}
Epoch: 38 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 730342.8182680551, 'Reg_MSE': 41884.1399355059, 'Class_CXE': 688458.6783325492}
Epoch: 39 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 727766.7147087837, 'Reg_MSE': 39321.091867755, 'Class_CXE': 688445.6228410286}
Epoch: 40 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 725357.7354468008, 'Reg_MSE': 36925.191221658, 'Class_CXE': 688432.5442251428}
Epoch: 41 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 723104.7411862973, 'Reg_MSE': 34685.2712432145, 'Class_CXE': 688419.4699430828}
Epoch: 42 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 720997.3135557103, 'Reg_MSE': 32590.9182206955, 'Class_CXE': 688406.3953350148}
Epoch: 43 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 719025.7391487409, 'Reg_MSE': 30632.4205170877, 'Class_CXE': 688393.3186316532}
Epoch: 44 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 717180.9613429696, 'Reg_MSE': 28800.7210570714, 'Class_CXE': 688380.2402858982}
Epoch: 45 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 715454.5405158547, 'Reg_MSE': 27087.373034255, 'Class_CXE': 688367.1674815997}
Epoch: 46 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 713838.5900735632, 'Reg_MSE': 25484.4986202772, 'Class_CXE': 688354.091453286}
Epoch: 47 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 712325.7534018582, 'Reg_MSE': 23984.7504722071, 'Class_CXE': 688341.0029296511}
Epoch: 48 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 710909.2063874636, 'Reg_MSE': 22581.2758484761, 'Class_CXE': 688327.9305389875}
Epoch: 49 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 709582.5109189807, 'Reg_MSE': 21267.6831564517, 'Class_CXE': 688314.827762529}
Epoch: 50 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 708339.7186078951, 'Reg_MSE': 20038.0107667584, 'Class_CXE': 688301.7078411367}
Epoch: 51 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 707175.3306172858, 'Reg_MSE': 18886.6979406395, 'Class_CXE': 688288.6326766463}
Epoch: 52 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 706084.1182646916, 'Reg_MSE': 17808.5577270747, 'Class_CXE': 688275.5605376168}
Epoch: 53 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 705061.2362884348, 'Reg_MSE': 16798.7516960935, 'Class_CXE': 688262.4845923414}
Epoch: 54 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 704102.1777275894, 'Reg_MSE': 15852.766383777, 'Class_CXE': 688249.4113438124}
Epoch: 55 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 703202.7186962188, 'Reg_MSE': 14966.3913328911, 'Class_CXE': 688236.3273633277}
Epoch: 56 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 702358.8963075056, 'Reg_MSE': 14135.6986209633, 'Class_CXE': 688223.1976865423}
Epoch: 57 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 701567.15366697, 'Reg_MSE': 13357.0237749531, 'Class_CXE': 688210.1298920169}
Epoch: 58 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 700824.004426702, 'Reg_MSE': 12626.9479785093, 'Class_CXE': 688197.0564481927}
Epoch: 59 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 700126.257831117, 'Reg_MSE': 11942.2814841777, 'Class_CXE': 688183.9763469393}
Epoch: 60 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 699470.9288221385, 'Reg_MSE': 11300.048148874, 'Class_CXE': 688170.8806732645}
Epoch: 61 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 698855.2646163853, 'Reg_MSE': 10697.4710164706, 'Class_CXE': 688157.7935999148}
Epoch: 62 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 698276.6866428788, 'Reg_MSE': 10131.9588765141, 'Class_CXE': 688144.7277663647}
Epoch: 63 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 697732.720500232, 'Reg_MSE': 9601.0937329033, 'Class_CXE': 688131.6267673287}
Epoch: 64 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 697221.1712131229, 'Reg_MSE': 9102.6191208469, 'Class_CXE': 688118.552092276}
Epoch: 65 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 696739.8911978736, 'Reg_MSE': 8634.4292145998, 'Class_CXE': 688105.4619832739}
Epoch: 66 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 696286.9453007823, 'Reg_MSE': 8194.558672381, 'Class_CXE': 688092.3866284013}
Epoch: 67 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 695860.4846192389, 'Reg_MSE': 7781.1731685073, 'Class_CXE': 688079.3114507316}
Epoch: 68 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 695458.7921412352, 'Reg_MSE': 7392.5605661675, 'Class_CXE': 688066.2315750677}
Epoch: 69 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 695080.2777083508, 'Reg_MSE': 7027.1226874198, 'Class_CXE': 688053.155020931}
Epoch: 70 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 694723.446943282, 'Reg_MSE': 6683.3676399389, 'Class_CXE': 688040.0793033431}
Epoch: 71 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 694386.8966962802, 'Reg_MSE': 6359.9026627852, 'Class_CXE': 688026.994033495}
Epoch: 72 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 694069.355254285, 'Reg_MSE': 6055.4274560258, 'Class_CXE': 688013.9277982592}
Epoch: 73 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 693769.5672845452, 'Reg_MSE': 5768.7279614218, 'Class_CXE': 688000.8393231233}
Epoch: 74 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 693486.4174144024, 'Reg_MSE': 5498.6705636207, 'Class_CXE': 687987.7468507816}
Epoch: 75 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 693218.868532555, 'Reg_MSE': 5244.1966833638, 'Class_CXE': 687974.6718491912}
Epoch: 76 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 692965.8987931121, 'Reg_MSE': 5004.317736151, 'Class_CXE': 687961.5810569611}
Epoch: 77 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 692726.6514289399, 'Reg_MSE': 4778.1104316057, 'Class_CXE': 687948.5409973342}
Epoch: 78 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 692500.1818258993, 'Reg_MSE': 4564.7123904619, 'Class_CXE': 687935.4694354373}
Epoch: 79 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 692285.7102196985, 'Reg_MSE': 4363.3180576589, 'Class_CXE': 687922.3921620396}
Epoch: 80 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 692082.4754727103, 'Reg_MSE': 4173.1748914886, 'Class_CXE': 687909.3005812217}
Epoch: 81 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 691889.7676220107, 'Reg_MSE': 3993.5798101015, 'Class_CXE': 687896.1878119091}
Epoch: 82 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 691706.9987454412, 'Reg_MSE': 3823.8758779403, 'Class_CXE': 687883.1228675009}
Epoch: 83 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 691533.5199538629, 'Reg_MSE': 3663.4492158583, 'Class_CXE': 687870.0707380046}
Epoch: 84 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 691368.710106875, 'Reg_MSE': 3511.7261197745, 'Class_CXE': 687856.9839871005}
Epoch: 85 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 691212.077039544, 'Reg_MSE': 3368.1703737483, 'Class_CXE': 687843.9066657957}
Epoch: 86 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 691063.1151026819, 'Reg_MSE': 3232.2807443129, 'Class_CXE': 687830.834358369}
Epoch: 87 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 690921.3387802457, 'Reg_MSE': 3103.5886437964, 'Class_CXE': 687817.7501364494}
Epoch: 88 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 690786.3156919683, 'Reg_MSE': 2981.6559511948, 'Class_CXE': 687804.6597407735}
Epoch: 89 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 690657.6591038642, 'Reg_MSE': 2866.0729799331, 'Class_CXE': 687791.5861239311}
Epoch: 90 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 690534.9666654468, 'Reg_MSE': 2756.4565825761, 'Class_CXE': 687778.5100828706}
Epoch: 91 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 690417.8877873917, 'Reg_MSE': 2652.4483832208, 'Class_CXE': 687765.4394041708}
Epoch: 92 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 690306.0619960732, 'Reg_MSE': 2553.7131289345, 'Class_CXE': 687752.3488671387}
Epoch: 93 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 690199.2188407166, 'Reg_MSE': 2459.9371521845, 'Class_CXE': 687739.2816885321}
Epoch: 94 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 690097.0227533873, 'Reg_MSE': 2370.8269367521, 'Class_CXE': 687726.1958166353}
Epoch: 95 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689999.2266061894, 'Reg_MSE': 2286.1077801334, 'Class_CXE': 687713.118826056}
Epoch: 96 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689905.5532156912, 'Reg_MSE': 2205.5225459025, 'Class_CXE': 687700.0306697887}
Epoch: 97 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689815.7874016832, 'Reg_MSE': 2128.8304999541, 'Class_CXE': 687686.9569017291}
Epoch: 98 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689729.6955740374, 'Reg_MSE': 2055.806224956, 'Class_CXE': 687673.8893490814}
Epoch: 99 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689647.0231907107, 'Reg_MSE': 1986.2386077261, 'Class_CXE': 687660.7845829846}
Epoch: 100 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689567.6420851992, 'Reg_MSE': 1919.9298946043, 'Class_CXE': 687647.712190595}
Epoch: 101 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689491.2630996503, 'Reg_MSE': 1856.6948102271, 'Class_CXE': 687634.5682894231}
Epoch: 102 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689417.856126167, 'Reg_MSE': 1796.3597354202, 'Class_CXE': 687621.4963907468}
Epoch: 103 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689347.1847957202, 'Reg_MSE': 1738.7619402176, 'Class_CXE': 687608.4228555026}
Epoch: 104 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689279.100012029, 'Reg_MSE': 1683.748868284, 'Class_CXE': 687595.351143745}
Epoch: 105 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689213.4554676801, 'Reg_MSE': 1631.1774692698, 'Class_CXE': 687582.2779984103}
Epoch: 106 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689150.1379875913, 'Reg_MSE': 1580.913575864, 'Class_CXE': 687569.2244117273}
Epoch: 107 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689088.9856186993, 'Reg_MSE': 1532.8313225284, 'Class_CXE': 687556.1542961709}
Epoch: 108 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689029.8798822629, 'Reg_MSE': 1486.8126031004, 'Class_CXE': 687543.0672791625}
Epoch: 109 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688972.7331277395, 'Reg_MSE': 1442.7465646431, 'Class_CXE': 687529.9865630964}
Epoch: 110 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688917.4529212967, 'Reg_MSE': 1400.5291350987, 'Class_CXE': 687516.923786198}
Epoch: 111 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688863.8839728262, 'Reg_MSE': 1360.0625824649, 'Class_CXE': 687503.8213903613}
Epoch: 112 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688811.9901882351, 'Reg_MSE': 1321.2551033715, 'Class_CXE': 687490.7350848636}
Epoch: 113 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688761.6830605407, 'Reg_MSE': 1284.0204390745, 'Class_CXE': 687477.6626214662}
Epoch: 114 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688713.0036310005, 'Reg_MSE': 1248.2775170221, 'Class_CXE': 687464.7261139784}
Epoch: 115 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688665.6084552985, 'Reg_MSE': 1213.9501162696, 'Class_CXE': 687451.6583390288}
Epoch: 116 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688619.5473148779, 'Reg_MSE': 1180.9665551388, 'Class_CXE': 687438.580759739}
Epoch: 117 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688574.7637065063, 'Reg_MSE': 1149.2593996238, 'Class_CXE': 687425.5043068825}
Epoch: 118 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688531.1947589673, 'Reg_MSE': 1118.7651911484, 'Class_CXE': 687412.4295678189}
Epoch: 119 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688488.7608143983, 'Reg_MSE': 1089.4241923738, 'Class_CXE': 687399.3366220245}
Epoch: 120 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688447.4149736025, 'Reg_MSE': 1061.1801498423, 'Class_CXE': 687386.2348237602}
Epoch: 121 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688407.1352278665, 'Reg_MSE': 1033.9800723267, 'Class_CXE': 687373.1551555397}
Epoch: 122 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688367.8756819074, 'Reg_MSE': 1007.7740238301, 'Class_CXE': 687360.1016580773}
Epoch: 123 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688329.5345089289, 'Reg_MSE': 982.5149302526, 'Class_CXE': 687347.0195786763}
Epoch: 124 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688292.0552631877, 'Reg_MSE': 958.1583988074, 'Class_CXE': 687333.8968643803}
Epoch: 125 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688255.4817492609, 'Reg_MSE': 934.662549332, 'Class_CXE': 687320.8191999289}
Epoch: 126 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688219.739639507, 'Reg_MSE': 911.9878566957, 'Class_CXE': 687307.7517828114}
Epoch: 127 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688184.7826423622, 'Reg_MSE': 890.0970035614, 'Class_CXE': 687294.6856388007}
Epoch: 128 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688524.9257310287, 'Reg_MSE': 868.9547428063, 'Class_CXE': 687655.9709882224}
Epoch: 129 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688491.4145009124, 'Reg_MSE': 848.5277689579, 'Class_CXE': 687642.8867319545}
Epoch: 130 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688458.6152474103, 'Reg_MSE': 828.7845980384, 'Class_CXE': 687629.8306493718}
Epoch: 131 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688426.4439325613, 'Reg_MSE': 809.6954552587, 'Class_CXE': 687616.7484773026}
Epoch: 132 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688769.2416836845, 'Reg_MSE': 791.2321700363, 'Class_CXE': 687978.0095136482}
Epoch: 133 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688738.2944248054, 'Reg_MSE': 773.3680778467, 'Class_CXE': 687964.9263469587}
Epoch: 134 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688707.8364397254, 'Reg_MSE': 756.0779284554, 'Class_CXE': 687951.7585112699}
Epoch: 135 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688677.9858840493, 'Reg_MSE': 739.3378001019, 'Class_CXE': 687938.6480839475}
Epoch: 136 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688648.6848957788, 'Reg_MSE': 723.1250192411, 'Class_CXE': 687925.5598765376}
Epoch: 137 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688619.8641383902, 'Reg_MSE': 707.4180854718, 'Class_CXE': 687912.4460529183}
Epoch: 138 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688591.552617438, 'Reg_MSE': 692.1966013065, 'Class_CXE': 687899.3560161315}
Epoch: 139 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688563.7058611929, 'Reg_MSE': 677.4412064618, 'Class_CXE': 687886.2646547311}
Epoch: 140 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688536.3051912569, 'Reg_MSE': 663.1335163679, 'Class_CXE': 687873.171674889}
Epoch: 141 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688509.3245022479, 'Reg_MSE': 649.256064619, 'Class_CXE': 687860.0684376288}
Epoch: 142 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688482.7738413778, 'Reg_MSE': 635.7922491012, 'Class_CXE': 687846.9815922766}
Epoch: 143 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688830.9889775687, 'Reg_MSE': 622.726281557, 'Class_CXE': 688208.2626960117}
Epoch: 144 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688805.2146846405, 'Reg_MSE': 610.0431403564, 'Class_CXE': 688195.1715442841}
Epoch: 145 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688779.8004846516, 'Reg_MSE': 597.7285262646, 'Class_CXE': 688182.071958387}
Epoch: 146 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688754.7481995549, 'Reg_MSE': 585.7688210081, 'Class_CXE': 688168.9793785467}
Epoch: 147 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688730.0342145372, 'Reg_MSE': 574.1510484541, 'Class_CXE': 688155.8831660831}
Epoch: 148 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688705.6368568854, 'Reg_MSE': 562.8628382314, 'Class_CXE': 688142.774018654}
Epoch: 149 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689055.9319355773, 'Reg_MSE': 551.8923916332, 'Class_CXE': 688504.0395439441}
Epoch: 150 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689032.162345506, 'Reg_MSE': 541.2284496504, 'Class_CXE': 688490.9338958556}
Epoch: 151 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689008.656341212, 'Reg_MSE': 530.8602629965, 'Class_CXE': 688477.7960782155}
Epoch: 152 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688985.4823696224, 'Reg_MSE': 520.7775639937, 'Class_CXE': 688464.7048056287}
Epoch: 153 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688962.5440347468, 'Reg_MSE': 510.9705401986, 'Class_CXE': 688451.5734945482}
Epoch: 154 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 688939.9017219571, 'Reg_MSE': 501.4298096529, 'Class_CXE': 688438.4719123043}
Epoch: 155 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689291.8788830448, 'Reg_MSE': 492.146397655, 'Class_CXE': 688799.7324853898}
Epoch: 156 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689269.7604911493, 'Reg_MSE': 483.1117149515, 'Class_CXE': 688786.6487761978}
Epoch: 157 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689247.8404127291, 'Reg_MSE': 474.3175372578, 'Class_CXE': 688773.5228754713}
Epoch: 158 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689226.1726190671, 'Reg_MSE': 465.7559860212, 'Class_CXE': 688760.4166330459}
Epoch: 159 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689204.7384769846, 'Reg_MSE': 457.4195103456, 'Class_CXE': 688747.318966639}
Epoch: 160 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689183.4799642569, 'Reg_MSE': 449.3008700037, 'Class_CXE': 688734.1790942532}
Epoch: 161 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689162.4501038014, 'Reg_MSE': 441.3931194656, 'Class_CXE': 688721.0569843359}
Epoch: 162 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689141.6320327548, 'Reg_MSE': 433.689592879, 'Class_CXE': 688707.9424398758}
Epoch: 163 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689121.0169308393, 'Reg_MSE': 426.1838899401, 'Class_CXE': 688694.8330408991}
Epoch: 164 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689100.5947537323, 'Reg_MSE': 418.8698625974, 'Class_CXE': 688681.7248911349}
Epoch: 165 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689080.3524824126, 'Reg_MSE': 411.7416025357, 'Class_CXE': 688668.6108798769}
Epoch: 166 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689060.3015654468, 'Reg_MSE': 404.7934293909, 'Class_CXE': 688655.5081360559}
Epoch: 167 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689040.4205286116, 'Reg_MSE': 398.0198796483, 'Class_CXE': 688642.4006489633}
Epoch: 168 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689020.6972843892, 'Reg_MSE': 391.415696182, 'Class_CXE': 688629.2815882071}
Epoch: 169 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689749.630770136, 'Reg_MSE': 384.9758183932, 'Class_CXE': 689364.6549517429}
Epoch: 170 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689730.2551275676, 'Reg_MSE': 378.6953729129, 'Class_CXE': 689351.5597546546}
Epoch: 171 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689711.0034404016, 'Reg_MSE': 372.569664829, 'Class_CXE': 689338.4337755727}
Epoch: 172 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689691.8900444916, 'Reg_MSE': 366.5941694093, 'Class_CXE': 689325.2958750824}
Epoch: 173 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689672.9403642985, 'Reg_MSE': 360.7645242866, 'Class_CXE': 689312.1758400119}
Epoch: 174 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689654.1283027712, 'Reg_MSE': 355.0765220778, 'Class_CXE': 689299.0517806935}
Epoch: 175 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689635.4521574591, 'Reg_MSE': 349.5261034112, 'Class_CXE': 689285.9260540479}
Epoch: 176 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689616.9109447481, 'Reg_MSE': 344.1093503348, 'Class_CXE': 689272.8015944132}
Epoch: 177 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689598.4968006837, 'Reg_MSE': 338.8224800837, 'Class_CXE': 689259.6743206}
Epoch: 178 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689580.1882087111, 'Reg_MSE': 333.6618391842, 'Class_CXE': 689246.5263695269}
Epoch: 179 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689561.9979599678, 'Reg_MSE': 328.6238978735, 'Class_CXE': 689233.3740620944}
Epoch: 180 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689918.3300972113, 'Reg_MSE': 323.705244816, 'Class_CXE': 689594.6248523953}
Epoch: 181 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689900.3821395758, 'Reg_MSE': 318.9025820988, 'Class_CXE': 689581.4795574769}
Epoch: 182 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689882.5497802027, 'Reg_MSE': 314.2127204886, 'Class_CXE': 689568.3370597141}
Epoch: 183 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689864.8372355945, 'Reg_MSE': 309.6325749355, 'Class_CXE': 689555.2046606591}
Epoch: 184 - {'Class_acc': 0.5258829639889196, 'CXE+MSE': 689847.2167652493, 'Reg_MSE': 305.159160307, 'Class_CXE': 689542.0576049422}
Epoch: 185 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689846.7000242879, 'Reg_MSE': 300.7895873408, 'Class_CXE': 689545.9104369471}
Epoch: 186 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689829.3033910353, 'Reg_MSE': 296.5210588021, 'Class_CXE': 689532.7823322332}
Epoch: 187 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689811.9967038092, 'Reg_MSE': 292.3508658328, 'Class_CXE': 689519.6458379765}
Epoch: 188 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689794.7850771486, 'Reg_MSE': 288.2763844832, 'Class_CXE': 689506.5086926654}
Epoch: 189 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689777.6591857172, 'Reg_MSE': 284.2950724138, 'Class_CXE': 689493.3641133035}
Epoch: 190 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689760.6319199635, 'Reg_MSE': 280.4044657581, 'Class_CXE': 689480.2274542054}
Epoch: 191 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689743.7018214306, 'Reg_MSE': 276.6021761372, 'Class_CXE': 689467.0996452934}
Epoch: 192 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689726.8538283836, 'Reg_MSE': 272.885887817, 'Class_CXE': 689453.9679405666}
Epoch: 193 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689710.0895333579, 'Reg_MSE': 269.2533550007, 'Class_CXE': 689440.8361783571}
Epoch: 194 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689693.4161113296, 'Reg_MSE': 265.7023992478, 'Class_CXE': 689427.7137120818}
Epoch: 195 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689676.7979132488, 'Reg_MSE': 262.2309070134, 'Class_CXE': 689414.5670062354}
Epoch: 196 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 690034.6159659058, 'Reg_MSE': 258.8368273009, 'Class_CXE': 689775.7791386049}
Epoch: 197 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 690018.1407212543, 'Reg_MSE': 255.5181694214, 'Class_CXE': 689762.6225518329}
Epoch: 198 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 690001.752095683, 'Reg_MSE': 252.2730008545, 'Class_CXE': 689749.4790948285}
Epoch: 199 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689985.4303013341, 'Reg_MSE': 249.0994452039, 'Class_CXE': 689736.3308561302}
Epoch: 200 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689969.1892365923, 'Reg_MSE': 245.9956802448, 'Class_CXE': 689723.1935563475}
Epoch: 201 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689952.997277235, 'Reg_MSE': 242.9599360549, 'Class_CXE': 689710.0373411801}
Epoch: 202 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689936.902882166, 'Reg_MSE': 239.9904932278, 'Class_CXE': 689696.9123889381}
Epoch: 203 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689920.8421701265, 'Reg_MSE': 237.0856811622, 'Class_CXE': 689683.7564889643}
Epoch: 204 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689904.8627482438, 'Reg_MSE': 234.243876424, 'Class_CXE': 689670.6188718198}
Epoch: 205 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689888.9439793029, 'Reg_MSE': 231.4635011768, 'Class_CXE': 689657.4804781261}
Epoch: 206 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 689873.0911840735, 'Reg_MSE': 228.743021678, 'Class_CXE': 689644.3481623955}
Epoch: 207 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 690231.6343979392, 'Reg_MSE': 226.0809468361, 'Class_CXE': 690005.553451103}
Epoch: 208 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 690215.8842642034, 'Reg_MSE': 223.4758268278, 'Class_CXE': 689992.4084373756}
Epoch: 209 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 690200.1924083221, 'Reg_MSE': 220.9262517693, 'Class_CXE': 689979.2661565528}
Epoch: 210 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 690184.5297935376, 'Reg_MSE': 218.4308504416, 'Class_CXE': 689966.098943096}
Epoch: 211 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 690168.9208609653, 'Reg_MSE': 215.9882890658, 'Class_CXE': 689952.9325718995}
Epoch: 212 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 690527.7779382173, 'Reg_MSE': 213.5972701259, 'Class_CXE': 690314.1806680914}
Epoch: 213 - {'Class_acc': 0.5257963988919667, 'CXE+MSE': 690512.2836841003, 'Reg_MSE': 211.2565312376, 'Class_CXE': 690301.0271528626}
Epoch: 214 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690496.838601441, 'Reg_MSE': 208.9648440598, 'Class_CXE': 690287.8737573812}
Epoch: 215 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690481.4406332577, 'Reg_MSE': 206.7210132478, 'Class_CXE': 690274.7196200099}
Epoch: 216 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690466.0937228901, 'Reg_MSE': 204.5238754449, 'Class_CXE': 690261.5698474452}
Epoch: 217 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690450.7959317676, 'Reg_MSE': 202.3722983129, 'Class_CXE': 690248.4236334546}
Epoch: 218 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690435.5011504953, 'Reg_MSE': 200.2651795966, 'Class_CXE': 690235.2359708986}
Epoch: 219 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690420.2846475546, 'Reg_MSE': 198.2014462231, 'Class_CXE': 690222.0832013316}
Epoch: 220 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690405.1155343245, 'Reg_MSE': 196.1800534338, 'Class_CXE': 690208.9354808908}
Epoch: 221 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690389.9658160816, 'Reg_MSE': 194.199983947, 'Class_CXE': 690195.7658321346}
Epoch: 222 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690374.8748334767, 'Reg_MSE': 192.2602471498, 'Class_CXE': 690182.6145863269}
Epoch: 223 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690359.8399135098, 'Reg_MSE': 190.3598783192, 'Class_CXE': 690169.4800351906}
Epoch: 224 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690344.8185595698, 'Reg_MSE': 188.4979378694, 'Class_CXE': 690156.3206217004}
Epoch: 225 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690329.854324684, 'Reg_MSE': 186.6735106253, 'Class_CXE': 690143.1808140588}
Epoch: 226 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690314.8868206406, 'Reg_MSE': 184.8857051208, 'Class_CXE': 690130.0011155199}
Epoch: 227 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690299.9200672741, 'Reg_MSE': 183.133652921, 'Class_CXE': 690116.7864143532}
Epoch: 228 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690285.0668281809, 'Reg_MSE': 181.4165079665, 'Class_CXE': 690103.6503202145}
Epoch: 229 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690270.2224584952, 'Reg_MSE': 179.7334459401, 'Class_CXE': 690090.4890125551}
Epoch: 230 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690255.4191046503, 'Reg_MSE': 178.0836636542, 'Class_CXE': 690077.3354409961}
Epoch: 231 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690240.6471861883, 'Reg_MSE': 176.4663784578, 'Class_CXE': 690064.1808077304}
Epoch: 232 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690225.9132075058, 'Reg_MSE': 174.8808276638, 'Class_CXE': 690051.032379842}
Epoch: 233 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690211.2064277766, 'Reg_MSE': 173.326267993, 'Class_CXE': 690037.8801597836}
Epoch: 234 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690196.5299457684, 'Reg_MSE': 171.8019750378, 'Class_CXE': 690024.7279707306}
Epoch: 235 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690181.8631315358, 'Reg_MSE': 170.3072427412, 'Class_CXE': 690011.5558887946}
Epoch: 236 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690167.246159788, 'Reg_MSE': 168.8413828934, 'Class_CXE': 689998.4047768946}
Epoch: 237 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690152.6470985563, 'Reg_MSE': 167.4037246437, 'Class_CXE': 689985.2433739125}
Epoch: 238 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690138.0787357795, 'Reg_MSE': 165.9936140268, 'Class_CXE': 689972.0851217527}
Epoch: 239 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690123.5187732996, 'Reg_MSE': 164.6104135054, 'Class_CXE': 689958.9083597943}
Epoch: 240 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690108.9515381088, 'Reg_MSE': 163.2535015248, 'Class_CXE': 689945.698036584}
Epoch: 241 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690094.4700819991, 'Reg_MSE': 161.9222720825, 'Class_CXE': 689932.5478099165}
Epoch: 242 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690080.0443094677, 'Reg_MSE': 160.6161343105, 'Class_CXE': 689919.4281751572}
Epoch: 243 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690065.609276252, 'Reg_MSE': 159.3345120692, 'Class_CXE': 689906.2747641827}
Epoch: 244 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690051.2018022682, 'Reg_MSE': 158.076843555, 'Class_CXE': 689893.1249587132}
Epoch: 245 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690036.8159937797, 'Reg_MSE': 156.8425809181, 'Class_CXE': 689879.9734128616}
Epoch: 246 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690022.4260739984, 'Reg_MSE': 155.6311898923, 'Class_CXE': 689866.7948841061}
Epoch: 247 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690008.0725023216, 'Reg_MSE': 154.4421494357, 'Class_CXE': 689853.6303528858}
Epoch: 248 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 689993.7468584134, 'Reg_MSE': 153.2749513816, 'Class_CXE': 689840.4719070318}
Epoch: 249 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 689979.4495977749, 'Reg_MSE': 152.1291000994, 'Class_CXE': 689827.3204976754}
Epoch: 250 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 689965.1609098496, 'Reg_MSE': 151.0041121658, 'Class_CXE': 689814.1567976838}
Epoch: 251 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 689950.9205283719, 'Reg_MSE': 149.8995160451, 'Class_CXE': 689801.0210123268}
Epoch: 252 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 689936.6832935737, 'Reg_MSE': 148.8148517787, 'Class_CXE': 689787.868441795}
Epoch: 253 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 689922.4393581196, 'Reg_MSE': 147.7496706833, 'Class_CXE': 689774.6896874363}
Epoch: 254 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690282.5887554968, 'Reg_MSE': 146.7035350579, 'Class_CXE': 690135.8852204388}
Epoch: 255 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690268.3909957242, 'Reg_MSE': 145.6760178988, 'Class_CXE': 690122.7149778254}
Epoch: 256 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690254.2203739332, 'Reg_MSE': 144.666702623, 'Class_CXE': 690109.5536713102}
Epoch: 257 - {'Class_acc': 0.5257098337950139, 'CXE+MSE': 690614.4028674102, 'Reg_MSE': 143.6751827989, 'Class_CXE': 690470.7276846113}
Epoch: 258 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690600.2590128107, 'Reg_MSE': 142.7010618845, 'Class_CXE': 690457.5579509261}
Epoch: 259 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690586.1237990961, 'Reg_MSE': 141.7439529735, 'Class_CXE': 690444.3798461226}
Epoch: 260 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690572.0069676978, 'Reg_MSE': 140.8034785474, 'Class_CXE': 690431.2034891504}
Epoch: 261 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690557.9140078117, 'Reg_MSE': 139.8792702354, 'Class_CXE': 690418.0347375763}
Epoch: 262 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690543.8452282764, 'Reg_MSE': 138.9709685803, 'Class_CXE': 690404.8742596961}
Epoch: 263 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690529.7826608111, 'Reg_MSE': 138.0782228109, 'Class_CXE': 690391.7044380002}
Epoch: 264 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690515.7262552719, 'Reg_MSE': 137.2006906207, 'Class_CXE': 690378.5255646512}
Epoch: 265 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690501.6988168319, 'Reg_MSE': 136.3380379525, 'Class_CXE': 690365.3607788794}
Epoch: 266 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690487.6842978614, 'Reg_MSE': 135.4899387892, 'Class_CXE': 690352.1943590723}
Epoch: 267 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690473.6772251221, 'Reg_MSE': 134.6560749495, 'Class_CXE': 690339.0211501726}
Epoch: 268 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690459.6637135607, 'Reg_MSE': 133.8361358899, 'Class_CXE': 690325.8275776708}
Epoch: 269 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690445.6982621913, 'Reg_MSE': 133.0298185115, 'Class_CXE': 690312.6684436798}
Epoch: 270 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690431.7287603489, 'Reg_MSE': 132.236826972, 'Class_CXE': 690299.4919333769}
Epoch: 271 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690417.7778858489, 'Reg_MSE': 131.4568725032, 'Class_CXE': 690286.3210133457}
Epoch: 272 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690403.8441702299, 'Reg_MSE': 130.6896732326, 'Class_CXE': 690273.1544969972}
Epoch: 273 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690389.9072289044, 'Reg_MSE': 129.9349540102, 'Class_CXE': 690259.9722748942}
Epoch: 274 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690376.0258875855, 'Reg_MSE': 129.1924462403, 'Class_CXE': 690246.8334413451}
Epoch: 275 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690362.1190726974, 'Reg_MSE': 128.4618877163, 'Class_CXE': 690233.657184981}
Epoch: 276 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690348.2248028906, 'Reg_MSE': 127.743022462, 'Class_CXE': 690220.4817804286}
Epoch: 277 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690334.3466799561, 'Reg_MSE': 127.0356005747, 'Class_CXE': 690207.3110793814}
Epoch: 278 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690320.4857797553, 'Reg_MSE': 126.3393780747, 'Class_CXE': 690194.1464016805}
Epoch: 279 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690680.9769634231, 'Reg_MSE': 125.6541167566, 'Class_CXE': 690555.3228466664}
Epoch: 280 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690667.1724816967, 'Reg_MSE': 124.9795840465, 'Class_CXE': 690542.1928976502}
Epoch: 281 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690653.3287503769, 'Reg_MSE': 124.3155528611, 'Class_CXE': 690529.0131975159}
Epoch: 282 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690639.4925206797, 'Reg_MSE': 123.661801472, 'Class_CXE': 690515.8307192078}
Epoch: 283 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690625.676116453, 'Reg_MSE': 123.0181133724, 'Class_CXE': 690502.6580030806}
Epoch: 284 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690611.8705768965, 'Reg_MSE': 122.3842771479, 'Class_CXE': 690489.4862997485}
Epoch: 285 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690598.062264262, 'Reg_MSE': 121.7600863504, 'Class_CXE': 690476.3021779116}
Epoch: 286 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690584.2783606541, 'Reg_MSE': 121.1453393752, 'Class_CXE': 690463.1330212789}
Epoch: 287 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690570.4943429766, 'Reg_MSE': 120.5398393416, 'Class_CXE': 690449.954503635}
Epoch: 288 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690556.7226155058, 'Reg_MSE': 119.9433939762, 'Class_CXE': 690436.7792215296}
Epoch: 289 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690542.9441006962, 'Reg_MSE': 119.3558154993, 'Class_CXE': 690423.5882851969}
Epoch: 290 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690529.1898928488, 'Reg_MSE': 118.7769205144, 'Class_CXE': 690410.4129723344}
Epoch: 291 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690515.3991585748, 'Reg_MSE': 118.2065299001, 'Class_CXE': 690397.1926286747}
Epoch: 292 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690501.5813932365, 'Reg_MSE': 117.6444687056, 'Class_CXE': 690383.9369245309}
Epoch: 293 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690487.8536562026, 'Reg_MSE': 117.0905660478, 'Class_CXE': 690370.7630901547}
Epoch: 294 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690474.1298138241, 'Reg_MSE': 116.5446550118, 'Class_CXE': 690357.5851588122}
Epoch: 295 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690460.4142919937, 'Reg_MSE': 116.0065725538, 'Class_CXE': 690344.40771944}
Epoch: 296 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690446.7102937951, 'Reg_MSE': 115.4761594061, 'Class_CXE': 690331.234134389}
Epoch: 297 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690433.0034065355, 'Reg_MSE': 114.9532599849, 'Class_CXE': 690318.0501465506}
Epoch: 298 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690419.3001343795, 'Reg_MSE': 114.4377223004, 'Class_CXE': 690304.8624120791}
Epoch: 299 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690405.6152399406, 'Reg_MSE': 113.9293978691, 'Class_CXE': 690291.6858420714}
Epoch: 300 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690391.9393262482, 'Reg_MSE': 113.4281416281, 'Class_CXE': 690278.5111846202}
Epoch: 301 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690378.2733164746, 'Reg_MSE': 112.9338118522, 'Class_CXE': 690265.3395046224}
Epoch: 302 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690364.6056684671, 'Reg_MSE': 112.4462700722, 'Class_CXE': 690252.1593983949}
Epoch: 303 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690350.9502909565, 'Reg_MSE': 111.9653809962, 'Class_CXE': 690238.9849099603}
Epoch: 304 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690337.3030955298, 'Reg_MSE': 111.491012432, 'Class_CXE': 690225.8120830978}
Epoch: 305 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690323.6639135099, 'Reg_MSE': 111.0230352122, 'Class_CXE': 690212.6408782977}
Epoch: 306 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690684.3841315527, 'Reg_MSE': 110.561323121, 'Class_CXE': 690573.8228084317}
Epoch: 307 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690670.7291203712, 'Reg_MSE': 110.1057528221, 'Class_CXE': 690560.623367549}
Epoch: 308 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690657.1233316072, 'Reg_MSE': 109.6562037896, 'Class_CXE': 690547.4671278176}
Epoch: 309 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690643.4842541436, 'Reg_MSE': 109.2125582401, 'Class_CXE': 690534.2716959035}
Epoch: 310 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690629.814305851, 'Reg_MSE': 108.7747010658, 'Class_CXE': 690521.0396047852}
Epoch: 311 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690616.2047378288, 'Reg_MSE': 108.3425197706, 'Class_CXE': 690507.8622180582}
Epoch: 312 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690602.5983896576, 'Reg_MSE': 107.9159044067, 'Class_CXE': 690494.6824852509}
Epoch: 313 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690963.3637901563, 'Reg_MSE': 107.4947475137, 'Class_CXE': 690855.8690426426}
Epoch: 314 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690949.7486949814, 'Reg_MSE': 107.0789440581, 'Class_CXE': 690842.6697509233}
Epoch: 315 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690936.1464210652, 'Reg_MSE': 106.6683913754, 'Class_CXE': 690829.4780296898}
Epoch: 316 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690922.5318365862, 'Reg_MSE': 106.2629891129, 'Class_CXE': 690816.2688474733}
Epoch: 317 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690908.9381317535, 'Reg_MSE': 105.8626391746, 'Class_CXE': 690803.0754925789}
Epoch: 318 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690895.3500649198, 'Reg_MSE': 105.4672456667, 'Class_CXE': 690789.8828192531}
Epoch: 319 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690881.7843944362, 'Reg_MSE': 105.0767148447, 'Class_CXE': 690776.7076795915}
Epoch: 320 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690868.2048951872, 'Reg_MSE': 104.6909550627, 'Class_CXE': 690763.5139401244}
Epoch: 321 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690854.6277489124, 'Reg_MSE': 104.3098767222, 'Class_CXE': 690750.3178721903}
Epoch: 322 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690841.0598714944, 'Reg_MSE': 103.9333922241, 'Class_CXE': 690737.1264792703}
Epoch: 323 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 690827.4766852584, 'Reg_MSE': 103.5614159202, 'Class_CXE': 690723.9152693383}
Epoch: 324 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690813.9076148706, 'Reg_MSE': 103.193864067, 'Class_CXE': 690710.7137508036}
Epoch: 325 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690800.3191056724, 'Reg_MSE': 102.8306547807, 'Class_CXE': 690697.4884508917}
Epoch: 326 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690786.769150063, 'Reg_MSE': 102.471707992, 'Class_CXE': 690684.297442071}
Epoch: 327 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690773.2266954267, 'Reg_MSE': 102.1169454037, 'Class_CXE': 690671.109750023}
Epoch: 328 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690759.6830240834, 'Reg_MSE': 101.7662904483, 'Class_CXE': 690657.9167336351}
Epoch: 329 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690746.1345963277, 'Reg_MSE': 101.4196682469, 'Class_CXE': 690644.7149280808}
Epoch: 330 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690732.6035414833, 'Reg_MSE': 101.0770055692, 'Class_CXE': 690631.526535914}
Epoch: 331 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690719.0922501936, 'Reg_MSE': 100.7382307944, 'Class_CXE': 690618.3540193992}
Epoch: 332 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690705.5722093813, 'Reg_MSE': 100.403273873, 'Class_CXE': 690605.1689355082}
Epoch: 333 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690692.035533218, 'Reg_MSE': 100.0720662899, 'Class_CXE': 690591.963466928}
Epoch: 334 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690678.5061295655, 'Reg_MSE': 99.7445410275, 'Class_CXE': 690578.7615885381}
Epoch: 335 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690664.9898709839, 'Reg_MSE': 99.4206325312, 'Class_CXE': 690565.5692384527}
Epoch: 336 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690651.4721964587, 'Reg_MSE': 99.1002766743, 'Class_CXE': 690552.3719197845}
Epoch: 337 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690637.9614134253, 'Reg_MSE': 98.7834107242, 'Class_CXE': 690539.178002701}
Epoch: 338 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690624.4606051289, 'Reg_MSE': 98.4699733102, 'Class_CXE': 690525.9906318187}
Epoch: 339 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690610.9169941695, 'Reg_MSE': 98.1599043909, 'Class_CXE': 690512.7570897785}
Epoch: 340 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690597.417072338, 'Reg_MSE': 97.8531452231, 'Class_CXE': 690499.5639271149}
Epoch: 341 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690583.8993679719, 'Reg_MSE': 97.5496383315, 'Class_CXE': 690486.3497296404}
Epoch: 342 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690570.4213465095, 'Reg_MSE': 97.2493274786, 'Class_CXE': 690473.1720190309}
Epoch: 343 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690556.9381233163, 'Reg_MSE': 96.952157636, 'Class_CXE': 690459.9859656802}
Epoch: 344 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690543.489861834, 'Reg_MSE': 96.658074956, 'Class_CXE': 690446.8317868779}
Epoch: 345 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690530.0474959058, 'Reg_MSE': 96.3670267438, 'Class_CXE': 690433.680469162}
Epoch: 346 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690516.5736123485, 'Reg_MSE': 96.0789614309, 'Class_CXE': 690420.4946509176}
Epoch: 347 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690503.0946878806, 'Reg_MSE': 95.7938285483, 'Class_CXE': 690407.3008593323}
Epoch: 348 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690489.6211945573, 'Reg_MSE': 95.5115787015, 'Class_CXE': 690394.1096158557}
Epoch: 349 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690476.1340638796, 'Reg_MSE': 95.2321635448, 'Class_CXE': 690380.9019003349}
Epoch: 350 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690462.6669944378, 'Reg_MSE': 94.9555357573, 'Class_CXE': 690367.7114586805}
Epoch: 351 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690449.1892785612, 'Reg_MSE': 94.6816490191, 'Class_CXE': 690354.507629542}
Epoch: 352 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690435.7373478244, 'Reg_MSE': 94.4104579876, 'Class_CXE': 690341.3268898368}
Epoch: 353 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690422.2773841223, 'Reg_MSE': 94.1419182755, 'Class_CXE': 690328.1354658467}
Epoch: 354 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690408.832784563, 'Reg_MSE': 93.8759864279, 'Class_CXE': 690314.9567981351}
Epoch: 355 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690395.3908777366, 'Reg_MSE': 93.6126199012, 'Class_CXE': 690301.7782578354}
Epoch: 356 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690381.9269334361, 'Reg_MSE': 93.3517770421, 'Class_CXE': 690288.5751563939}
Epoch: 357 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690368.4554289277, 'Reg_MSE': 93.0934170665, 'Class_CXE': 690275.3620118612}
Epoch: 358 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690354.9981542432, 'Reg_MSE': 92.8375000399, 'Class_CXE': 690262.1606542033}
Epoch: 359 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690341.5543071718, 'Reg_MSE': 92.5839868576, 'Class_CXE': 690248.9703203143}
Epoch: 360 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690328.104756149, 'Reg_MSE': 92.3328392257, 'Class_CXE': 690235.7719169233}
Epoch: 361 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690314.6675689454, 'Reg_MSE': 92.0840196423, 'Class_CXE': 690222.5835493031}
Epoch: 362 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690301.2373638939, 'Reg_MSE': 91.8374913793, 'Class_CXE': 690209.3998725145}
Epoch: 363 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690287.7967544561, 'Reg_MSE': 91.5932184648, 'Class_CXE': 690196.2035359914}
Epoch: 364 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690274.3579728019, 'Reg_MSE': 91.3511656657, 'Class_CXE': 690183.0068071362}
Epoch: 365 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690260.9165363593, 'Reg_MSE': 91.1112984705, 'Class_CXE': 690169.8052378888}
Epoch: 366 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690247.5147341266, 'Reg_MSE': 90.8735830733, 'Class_CXE': 690156.6411510534}
Epoch: 367 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690234.0749036755, 'Reg_MSE': 90.637986357, 'Class_CXE': 690143.4369173185}
Epoch: 368 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690220.6558475479, 'Reg_MSE': 90.4044758783, 'Class_CXE': 690130.2513716696}
Epoch: 369 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690207.2382684831, 'Reg_MSE': 90.1730198518, 'Class_CXE': 690117.0652486313}
Epoch: 370 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690193.8236791226, 'Reg_MSE': 89.9435871352, 'Class_CXE': 690103.8800919874}
Epoch: 371 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690180.4009732745, 'Reg_MSE': 89.7161472147, 'Class_CXE': 690090.6848260597}
Epoch: 372 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690166.9520992152, 'Reg_MSE': 89.4906701905, 'Class_CXE': 690077.4614290247}
Epoch: 373 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690153.5365753882, 'Reg_MSE': 89.2671267634, 'Class_CXE': 690064.2694486248}
Epoch: 374 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690140.1266350762, 'Reg_MSE': 89.0454882205, 'Class_CXE': 690051.0811468557}
Epoch: 375 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690126.7096256763, 'Reg_MSE': 88.8257264222, 'Class_CXE': 690037.883899254}
Epoch: 376 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690113.1050951876, 'Reg_MSE': 88.6078137896, 'Class_CXE': 690024.4972813979}
Epoch: 377 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690099.7092148779, 'Reg_MSE': 88.3917232912, 'Class_CXE': 690011.3174915867}
Epoch: 378 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690086.3080624708, 'Reg_MSE': 88.1774284309, 'Class_CXE': 689998.1306340399}
Epoch: 379 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690072.9050620417, 'Reg_MSE': 87.9649032361, 'Class_CXE': 689984.9401588056}
Epoch: 380 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690059.509165671, 'Reg_MSE': 87.7541222454, 'Class_CXE': 689971.7550434256}
Epoch: 381 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690046.1109367624, 'Reg_MSE': 87.5450604977, 'Class_CXE': 689958.5658762647}
Epoch: 382 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690032.6312753746, 'Reg_MSE': 87.3376935205, 'Class_CXE': 689945.2935818541}
Epoch: 383 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690393.595521769, 'Reg_MSE': 87.1319973192, 'Class_CXE': 690306.4635244498}
Epoch: 384 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690380.1985714601, 'Reg_MSE': 86.9279483664, 'Class_CXE': 690293.2706230937}
Epoch: 385 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690366.8268762961, 'Reg_MSE': 86.7255235913, 'Class_CXE': 690280.1013527048}
Epoch: 386 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690353.4039846855, 'Reg_MSE': 86.5247003697, 'Class_CXE': 690266.8792843159}
Epoch: 387 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690340.0081849421, 'Reg_MSE': 86.3254565137, 'Class_CXE': 690253.6827284284}
Epoch: 388 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690326.6211745732, 'Reg_MSE': 86.1277702627, 'Class_CXE': 690240.4934043105}
Epoch: 389 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690313.2001942187, 'Reg_MSE': 85.9316202729, 'Class_CXE': 690227.2685739457}
Epoch: 390 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690299.8140479055, 'Reg_MSE': 85.7369856088, 'Class_CXE': 690214.0770622967}
Epoch: 391 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690286.4167619199, 'Reg_MSE': 85.543845734, 'Class_CXE': 690200.8729161859}
Epoch: 392 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690273.0254245937, 'Reg_MSE': 85.3521805022, 'Class_CXE': 690187.6732440916}
Epoch: 393 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690259.6404528697, 'Reg_MSE': 85.1619701487, 'Class_CXE': 690174.478482721}
Epoch: 394 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690620.6271454918, 'Reg_MSE': 84.9731952817, 'Class_CXE': 690535.6539502101}
Epoch: 395 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690607.240781938, 'Reg_MSE': 84.7858368746, 'Class_CXE': 690522.4549450633}
Epoch: 396 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690593.851909678, 'Reg_MSE': 84.5998762575, 'Class_CXE': 690509.2520334205}
Epoch: 397 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690580.455587911, 'Reg_MSE': 84.4152951097, 'Class_CXE': 690496.0402928012}
Epoch: 398 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690567.0635019607, 'Reg_MSE': 84.2320754517, 'Class_CXE': 690482.831426509}
Epoch: 399 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690553.6668657043, 'Reg_MSE': 84.0501996379, 'Class_CXE': 690469.6166660665}
Epoch: 400 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690540.2309110934, 'Reg_MSE': 83.8696503493, 'Class_CXE': 690456.3612607442}
Epoch: 401 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690526.8208911797, 'Reg_MSE': 83.6904105864, 'Class_CXE': 690443.1304805933}
Epoch: 402 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690513.4406012207, 'Reg_MSE': 83.5124636622, 'Class_CXE': 690429.9281375585}
Epoch: 403 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690500.0654360546, 'Reg_MSE': 83.3357931953, 'Class_CXE': 690416.7296428593}
Epoch: 404 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690486.6903482766, 'Reg_MSE': 83.1603831033, 'Class_CXE': 690403.5299651733}
Epoch: 405 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690473.2966640259, 'Reg_MSE': 82.9862175964, 'Class_CXE': 690390.3104464294}
Epoch: 406 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690459.9075243912, 'Reg_MSE': 82.8132811711, 'Class_CXE': 690377.09424322}
Epoch: 407 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690446.5665623137, 'Reg_MSE': 82.6415586036, 'Class_CXE': 690363.9250037101}
Epoch: 408 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690433.191682891, 'Reg_MSE': 82.4710349444, 'Class_CXE': 690350.7206479466}
Epoch: 409 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690419.8090207005, 'Reg_MSE': 82.3016955118, 'Class_CXE': 690337.5073251887}
Epoch: 410 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690406.4334927037, 'Reg_MSE': 82.1335258866, 'Class_CXE': 690324.2999668171}
Epoch: 411 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690393.0747181752, 'Reg_MSE': 81.9665119061, 'Class_CXE': 690311.1082062691}
Epoch: 412 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690379.6881414325, 'Reg_MSE': 81.8006396586, 'Class_CXE': 690297.8875017739}
Epoch: 413 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690366.310616788, 'Reg_MSE': 81.6358954785, 'Class_CXE': 690284.6747213094}
Epoch: 414 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690353.0036901834, 'Reg_MSE': 81.4722659404, 'Class_CXE': 690271.531424243}
Epoch: 415 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690339.6290422321, 'Reg_MSE': 81.3097378545, 'Class_CXE': 690258.3193043775}
Epoch: 416 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690326.2643396801, 'Reg_MSE': 81.1482982612, 'Class_CXE': 690245.1160414189}
Epoch: 417 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690312.8975042746, 'Reg_MSE': 80.9879344265, 'Class_CXE': 690231.9095698481}
Epoch: 418 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690299.5244513055, 'Reg_MSE': 80.8286338368, 'Class_CXE': 690218.6958174687}
Epoch: 419 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690286.1568716189, 'Reg_MSE': 80.6703841949, 'Class_CXE': 690205.486487424}
Epoch: 420 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690272.7938242898, 'Reg_MSE': 80.5131734147, 'Class_CXE': 690192.2806508751}
Epoch: 421 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690259.4314137506, 'Reg_MSE': 80.3569896173, 'Class_CXE': 690179.0744241333}
Epoch: 422 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690246.0989829771, 'Reg_MSE': 80.2018211262, 'Class_CXE': 690165.8971618509}
Epoch: 423 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690607.0794112169, 'Reg_MSE': 80.0476564635, 'Class_CXE': 690527.0317547534}
Epoch: 424 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690593.7034890872, 'Reg_MSE': 79.8944843452, 'Class_CXE': 690513.809004742}
Epoch: 425 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690580.3377666632, 'Reg_MSE': 79.7422936778, 'Class_CXE': 690500.5954729854}
Epoch: 426 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690566.9780460207, 'Reg_MSE': 79.5910735535, 'Class_CXE': 690487.3869724672}
Epoch: 427 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690553.6080831549, 'Reg_MSE': 79.4408132471, 'Class_CXE': 690474.1672699078}
Epoch: 428 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690540.288719713, 'Reg_MSE': 79.2915022118, 'Class_CXE': 690460.9972175012}
Epoch: 429 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690526.9148042565, 'Reg_MSE': 79.1431300755, 'Class_CXE': 690447.771674181}
Epoch: 430 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690513.5464023672, 'Reg_MSE': 78.9956866374, 'Class_CXE': 690434.5507157298}
Epoch: 431 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690500.178617986, 'Reg_MSE': 78.849161864, 'Class_CXE': 690421.329456122}
Epoch: 432 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690486.8311400305, 'Reg_MSE': 78.7035458864, 'Class_CXE': 690408.1275941441}
Epoch: 433 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690473.4710805404, 'Reg_MSE': 78.5588289959, 'Class_CXE': 690394.9122515445}
Epoch: 434 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690460.1106614367, 'Reg_MSE': 78.4150016418, 'Class_CXE': 690381.6956597948}
Epoch: 435 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690446.7558247077, 'Reg_MSE': 78.2720544273, 'Class_CXE': 690368.4837702804}
Epoch: 436 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690433.3902286219, 'Reg_MSE': 78.1299781067, 'Class_CXE': 690355.2602505152}
Epoch: 437 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690420.0622603432, 'Reg_MSE': 77.9887635825, 'Class_CXE': 690342.0734967607}
Epoch: 438 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690406.7067481751, 'Reg_MSE': 77.848401902, 'Class_CXE': 690328.8583462731}
Epoch: 439 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690393.35889733, 'Reg_MSE': 77.7088842545, 'Class_CXE': 690315.6500130756}
Epoch: 440 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690380.0115839767, 'Reg_MSE': 77.5702019685, 'Class_CXE': 690302.4413820081}
Epoch: 441 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690366.6562644824, 'Reg_MSE': 77.4323465089, 'Class_CXE': 690289.2239179735}
Epoch: 442 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690353.3017927114, 'Reg_MSE': 77.295309474, 'Class_CXE': 690276.0064832374}
Epoch: 443 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690339.9366552744, 'Reg_MSE': 77.1590825933, 'Class_CXE': 690262.7775726811}
Epoch: 444 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690700.9545002584, 'Reg_MSE': 77.0236577243, 'Class_CXE': 690623.9308425342}
Epoch: 445 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690687.5898338866, 'Reg_MSE': 76.8890268502, 'Class_CXE': 690610.7008070364}
Epoch: 446 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690674.2491478074, 'Reg_MSE': 76.7551820774, 'Class_CXE': 690597.49396573}
Epoch: 447 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690660.8901571435, 'Reg_MSE': 76.6221156332, 'Class_CXE': 690584.2680415103}
Epoch: 448 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690647.4877036719, 'Reg_MSE': 76.4898198631, 'Class_CXE': 690570.9978838088}
Epoch: 449 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690634.1299825563, 'Reg_MSE': 76.3582872284, 'Class_CXE': 690557.771695328}
Epoch: 450 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690620.791051476, 'Reg_MSE': 76.2275103042, 'Class_CXE': 690544.5635411718}
Epoch: 451 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690607.4471009305, 'Reg_MSE': 76.0974817773, 'Class_CXE': 690531.3496191532}
Epoch: 452 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690594.0871299875, 'Reg_MSE': 75.9681944433, 'Class_CXE': 690518.1189355443}
Epoch: 453 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690580.7404342282, 'Reg_MSE': 75.8396412051, 'Class_CXE': 690504.9007930231}
Epoch: 454 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690567.3997179568, 'Reg_MSE': 75.7118150706, 'Class_CXE': 690491.6879028862}
Epoch: 455 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690554.0381614862, 'Reg_MSE': 75.5847091506, 'Class_CXE': 690478.4534523356}
Epoch: 456 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690540.6862629105, 'Reg_MSE': 75.4583166569, 'Class_CXE': 690465.2279462536}
Epoch: 457 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690527.3378346638, 'Reg_MSE': 75.3326309002, 'Class_CXE': 690452.0052037636}
Epoch: 458 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690513.9918420409, 'Reg_MSE': 75.2076452882, 'Class_CXE': 690438.7841967527}
Epoch: 459 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690500.6501862298, 'Reg_MSE': 75.0833533239, 'Class_CXE': 690425.566832906}
Epoch: 460 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690487.3112350588, 'Reg_MSE': 74.9597486037, 'Class_CXE': 690412.3514864551}
Epoch: 461 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690473.9390807853, 'Reg_MSE': 74.8368248153, 'Class_CXE': 690399.10225597}
Epoch: 462 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690460.5887841459, 'Reg_MSE': 74.7145757365, 'Class_CXE': 690385.8742084093}
Epoch: 463 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690447.2424454112, 'Reg_MSE': 74.5929952329, 'Class_CXE': 690372.6494501783}
Epoch: 464 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690433.9110660134, 'Reg_MSE': 74.4720772566, 'Class_CXE': 690359.4389887567}
Epoch: 465 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690420.575343208, 'Reg_MSE': 74.3518158445, 'Class_CXE': 690346.2235273635}
Epoch: 466 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690407.2297557346, 'Reg_MSE': 74.2322051167, 'Class_CXE': 690332.9975506179}
Epoch: 467 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690393.8887529912, 'Reg_MSE': 74.1132392748, 'Class_CXE': 690319.7755137164}
Epoch: 468 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690380.5333258958, 'Reg_MSE': 73.9949126004, 'Class_CXE': 690306.5384132954}
Epoch: 469 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690367.2036816286, 'Reg_MSE': 73.8772194537, 'Class_CXE': 690293.3264621749}
Epoch: 470 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690353.8927393969, 'Reg_MSE': 73.760154272, 'Class_CXE': 690280.1325851249}
Epoch: 471 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690340.5463334715, 'Reg_MSE': 73.6437115682, 'Class_CXE': 690266.9026219033}
Epoch: 472 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690327.2088990435, 'Reg_MSE': 73.5278859294, 'Class_CXE': 690253.681013114}
Epoch: 473 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690313.8757115966, 'Reg_MSE': 73.4126720155, 'Class_CXE': 690240.4630395811}
Epoch: 474 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690300.5515296808, 'Reg_MSE': 73.298064558, 'Class_CXE': 690227.2534651229}
Epoch: 475 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690287.2555786484, 'Reg_MSE': 73.1840583587, 'Class_CXE': 690214.0715202898}
Epoch: 476 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690273.9265808483, 'Reg_MSE': 73.070648288, 'Class_CXE': 690200.8559325603}
Epoch: 477 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690260.5973618504, 'Reg_MSE': 72.957829284, 'Class_CXE': 690187.6395325664}
Epoch: 478 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690247.2512283922, 'Reg_MSE': 72.8455963514, 'Class_CXE': 690174.4056320408}
Epoch: 479 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690233.924571321, 'Reg_MSE': 72.7339445599, 'Class_CXE': 690161.1906267611}
Epoch: 480 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690220.5983007937, 'Reg_MSE': 72.6228690433, 'Class_CXE': 690147.9754317504}
Epoch: 481 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690207.2681383735, 'Reg_MSE': 72.5123649982, 'Class_CXE': 690134.7557733753}
Epoch: 482 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690193.9308954909, 'Reg_MSE': 72.4024276829, 'Class_CXE': 690121.528467808}
Epoch: 483 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690180.6013883587, 'Reg_MSE': 72.2930524163, 'Class_CXE': 690108.3083359424}
Epoch: 484 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690541.6319652518, 'Reg_MSE': 72.1842345769, 'Class_CXE': 690469.4477306749}
Epoch: 485 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690528.3057720845, 'Reg_MSE': 72.0759696016, 'Class_CXE': 690456.2298024829}
Epoch: 486 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690514.9755036077, 'Reg_MSE': 71.9682529847, 'Class_CXE': 690443.007250623}
Epoch: 487 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690501.6323792877, 'Reg_MSE': 71.861080277, 'Class_CXE': 690429.7712990107}
Epoch: 488 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690488.3191644332, 'Reg_MSE': 71.7544470845, 'Class_CXE': 690416.5647173487}
Epoch: 489 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690474.9953956818, 'Reg_MSE': 71.648349068, 'Class_CXE': 690403.3470466138}
Epoch: 490 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690461.6631657633, 'Reg_MSE': 71.5427819414, 'Class_CXE': 690390.1203838219}
Epoch: 491 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690448.3399015152, 'Reg_MSE': 71.4377414714, 'Class_CXE': 690376.9021600437}
Epoch: 492 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690435.0054110787, 'Reg_MSE': 71.3332234763, 'Class_CXE': 690363.6721876024}
Epoch: 493 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690421.417477172, 'Reg_MSE': 71.2292238249, 'Class_CXE': 690350.1882533471}
Epoch: 494 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690408.0865310947, 'Reg_MSE': 71.1257384361, 'Class_CXE': 690336.9607926586}
Epoch: 495 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690394.7461420872, 'Reg_MSE': 71.0227632778, 'Class_CXE': 690323.7233788094}
Epoch: 496 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690381.3716237012, 'Reg_MSE': 70.9202943659, 'Class_CXE': 690310.4513293353}
Epoch: 497 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690368.020518271, 'Reg_MSE': 70.8183277639, 'Class_CXE': 690297.2021905071}
Epoch: 498 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690354.687415526, 'Reg_MSE': 70.7168595816, 'Class_CXE': 690283.9705559444}
Epoch: 499 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690341.3612774058, 'Reg_MSE': 70.6158859746, 'Class_CXE': 690270.7453914311}
Epoch: 500 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690328.0337247042, 'Reg_MSE': 70.5154031436, 'Class_CXE': 690257.5183215606}
Epoch: 501 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690314.7071587236, 'Reg_MSE': 70.4154073332, 'Class_CXE': 690244.2917513904}
Epoch: 502 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690301.386337335, 'Reg_MSE': 70.3158948319, 'Class_CXE': 690231.0704425031}
Epoch: 503 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690288.0615454769, 'Reg_MSE': 70.2168619707, 'Class_CXE': 690217.8446835062}
Epoch: 504 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690274.7411456023, 'Reg_MSE': 70.1183051226, 'Class_CXE': 690204.6228404797}
Epoch: 505 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690261.4226612725, 'Reg_MSE': 70.0202207021, 'Class_CXE': 690191.4024405704}
Epoch: 506 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690248.0965043603, 'Reg_MSE': 69.9226051642, 'Class_CXE': 690178.1738991961}
Epoch: 507 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690234.7925286969, 'Reg_MSE': 69.8254550041, 'Class_CXE': 690164.9670736928}
Epoch: 508 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690238.4639312421, 'Reg_MSE': 69.728766756, 'Class_CXE': 690168.7351644861}
Epoch: 509 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690225.1549527806, 'Reg_MSE': 69.6325369933, 'Class_CXE': 690155.5224157873}
Epoch: 510 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690211.8192286578, 'Reg_MSE': 69.5367623269, 'Class_CXE': 690142.2824663308}
Epoch: 511 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690198.3890880087, 'Reg_MSE': 69.4414394056, 'Class_CXE': 690128.9476486031}
Epoch: 512 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690185.0613807064, 'Reg_MSE': 69.3465649148, 'Class_CXE': 690115.7148157916}
Epoch: 513 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690171.7520128975, 'Reg_MSE': 69.2521355762, 'Class_CXE': 690102.4998773213}
Epoch: 514 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690158.4183907494, 'Reg_MSE': 69.1581481474, 'Class_CXE': 690089.260242602}
Epoch: 515 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690145.0963721701, 'Reg_MSE': 69.0645994208, 'Class_CXE': 690076.0317727494}
Epoch: 516 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690131.7830633086, 'Reg_MSE': 68.9714862234, 'Class_CXE': 690062.8115770852}
Epoch: 517 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690118.4701485836, 'Reg_MSE': 68.8788054165, 'Class_CXE': 690049.5913431671}
Epoch: 518 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690105.1539451741, 'Reg_MSE': 68.7865538946, 'Class_CXE': 690036.3673912794}
Epoch: 519 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690091.82140448, 'Reg_MSE': 68.6947285852, 'Class_CXE': 690023.1266758948}
Epoch: 520 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690078.5010386399, 'Reg_MSE': 68.6033264483, 'Class_CXE': 690009.8977121916}
Epoch: 521 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690065.1933775009, 'Reg_MSE': 68.5123444758, 'Class_CXE': 689996.681033025}
Epoch: 522 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690051.8822229482, 'Reg_MSE': 68.4217796911, 'Class_CXE': 689983.4604432571}
Epoch: 523 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690038.5660988891, 'Reg_MSE': 68.3316291486, 'Class_CXE': 689970.2344697405}
Epoch: 524 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690025.2416340618, 'Reg_MSE': 68.2418899329, 'Class_CXE': 689956.9997441289}
Epoch: 525 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690011.9299110065, 'Reg_MSE': 68.152559159, 'Class_CXE': 689943.7773518475}
Epoch: 526 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689998.6701935551, 'Reg_MSE': 68.0636339713, 'Class_CXE': 689930.6065595838}
Epoch: 527 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689985.3916610789, 'Reg_MSE': 67.9751115432, 'Class_CXE': 689917.4165495357}
Epoch: 528 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690346.4419798966, 'Reg_MSE': 67.8869890771, 'Class_CXE': 690278.5549908195}
Epoch: 529 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690333.121630584, 'Reg_MSE': 67.7992638034, 'Class_CXE': 690265.3223667806}
Epoch: 530 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690319.805688408, 'Reg_MSE': 67.7119329804, 'Class_CXE': 690252.0937554275}
Epoch: 531 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690306.4468617042, 'Reg_MSE': 67.6249938939, 'Class_CXE': 690238.8218678103}
Epoch: 532 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690293.1426024231, 'Reg_MSE': 67.5384438566, 'Class_CXE': 690225.6041585664}
Epoch: 533 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690279.8285875014, 'Reg_MSE': 67.4522802078, 'Class_CXE': 690212.3763072937}
Epoch: 534 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690266.5129558829, 'Reg_MSE': 67.3665003132, 'Class_CXE': 690199.1464555698}
Epoch: 535 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690253.1835708932, 'Reg_MSE': 67.2811015642, 'Class_CXE': 690185.9024693291}
Epoch: 536 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690239.868999581, 'Reg_MSE': 67.1960813778, 'Class_CXE': 690172.6729182032}
Epoch: 537 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690226.5674128465, 'Reg_MSE': 67.111437196, 'Class_CXE': 690159.4559756505}
Epoch: 538 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690213.2524339978, 'Reg_MSE': 67.0271664856, 'Class_CXE': 690146.2252675123}
Epoch: 539 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690199.9374408594, 'Reg_MSE': 66.9432667378, 'Class_CXE': 690132.9941741215}
Epoch: 540 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690186.6282826159, 'Reg_MSE': 66.8597354678, 'Class_CXE': 690119.7685471481}
Epoch: 541 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690547.6697373341, 'Reg_MSE': 66.7765702145, 'Class_CXE': 690480.8931671196}
Epoch: 542 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690534.3537070797, 'Reg_MSE': 66.6937685402, 'Class_CXE': 690467.6599385395}
Epoch: 543 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690521.0273096076, 'Reg_MSE': 66.6113280303, 'Class_CXE': 690454.4159815772}
Epoch: 544 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690507.7089530203, 'Reg_MSE': 66.5292462928, 'Class_CXE': 690441.1797067275}
Epoch: 545 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690868.7652587437, 'Reg_MSE': 66.4475209583, 'Class_CXE': 690802.3177377854}
Epoch: 546 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690855.4393836325, 'Reg_MSE': 66.3661496792, 'Class_CXE': 690789.0732339533}
Epoch: 547 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690842.112479376, 'Reg_MSE': 66.28513013, 'Class_CXE': 690775.827349246}
Epoch: 548 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690828.7848321795, 'Reg_MSE': 66.2044600066, 'Class_CXE': 690762.580372173}
Epoch: 549 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690815.5000785211, 'Reg_MSE': 66.1241370259, 'Class_CXE': 690749.3759414952}
Epoch: 550 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690802.1748686391, 'Reg_MSE': 66.044158926, 'Class_CXE': 690736.1307097131}
Epoch: 551 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690788.8467601635, 'Reg_MSE': 65.9645234654, 'Class_CXE': 690722.8822366982}
Epoch: 552 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690775.5242714211, 'Reg_MSE': 65.8852284232, 'Class_CXE': 690709.639042998}
Epoch: 553 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690762.2013959744, 'Reg_MSE': 65.8062715983, 'Class_CXE': 690696.395124376}
Epoch: 554 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690748.8770166169, 'Reg_MSE': 65.7276508097, 'Class_CXE': 690683.1493658072}
Epoch: 555 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690735.559157067, 'Reg_MSE': 65.6493638956, 'Class_CXE': 690669.9097931713}
Epoch: 556 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690722.2440727178, 'Reg_MSE': 65.5714087139, 'Class_CXE': 690656.6726640039}
Epoch: 557 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690708.9147456471, 'Reg_MSE': 65.4937831412, 'Class_CXE': 690643.4209625059}
Epoch: 558 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690695.5902390416, 'Reg_MSE': 65.4164850731, 'Class_CXE': 690630.1737539686}
Epoch: 559 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690682.2793952974, 'Reg_MSE': 65.3395124237, 'Class_CXE': 690616.9398828737}
Epoch: 560 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690668.9594254467, 'Reg_MSE': 65.2628631253, 'Class_CXE': 690603.6965623214}
Epoch: 561 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690655.6340217759, 'Reg_MSE': 65.1865351285, 'Class_CXE': 690590.4474866474}
Epoch: 562 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690642.3182157772, 'Reg_MSE': 65.1105264014, 'Class_CXE': 690577.2076893757}
Epoch: 563 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690628.9638915567, 'Reg_MSE': 65.0348349301, 'Class_CXE': 690563.9290566266}
Epoch: 564 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690615.6218963176, 'Reg_MSE': 64.9594587178, 'Class_CXE': 690550.6624375997}
Epoch: 565 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690602.2955697536, 'Reg_MSE': 64.8843957851, 'Class_CXE': 690537.4111739686}
Epoch: 566 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690588.9776894542, 'Reg_MSE': 64.8096441692, 'Class_CXE': 690524.168045285}
Epoch: 567 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690575.6635084385, 'Reg_MSE': 64.7352019245, 'Class_CXE': 690510.928306514}
Epoch: 568 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690562.3499594872, 'Reg_MSE': 64.6610671216, 'Class_CXE': 690497.6888923657}
Epoch: 569 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690549.02843324, 'Reg_MSE': 64.5872378476, 'Class_CXE': 690484.4411953924}
Epoch: 570 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690535.7182097854, 'Reg_MSE': 64.5137122057, 'Class_CXE': 690471.2044975797}
Epoch: 571 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690522.3985587609, 'Reg_MSE': 64.4404883148, 'Class_CXE': 690457.958070446}
Epoch: 572 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690883.462420262, 'Reg_MSE': 64.36756431, 'Class_CXE': 690819.0948559521}
Epoch: 573 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690870.1261153041, 'Reg_MSE': 64.2949383415, 'Class_CXE': 690805.8311769626}
Epoch: 574 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690856.8163996919, 'Reg_MSE': 64.2226085751, 'Class_CXE': 690792.5937911167}
Epoch: 575 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690843.4997036293, 'Reg_MSE': 64.1505731917, 'Class_CXE': 690779.3491304376}
Epoch: 576 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690830.1777992481, 'Reg_MSE': 64.0788303873, 'Class_CXE': 690766.0989688608}
Epoch: 577 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690816.848417134, 'Reg_MSE': 64.0073783726, 'Class_CXE': 690752.8410387614}
Epoch: 578 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690803.5298837521, 'Reg_MSE': 63.9362153729, 'Class_CXE': 690739.5936683792}
Epoch: 579 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690790.2067316814, 'Reg_MSE': 63.8653396281, 'Class_CXE': 690726.3413920533}
Epoch: 580 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690776.8779795826, 'Reg_MSE': 63.7947493922, 'Class_CXE': 690713.0832301904}
Epoch: 581 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690763.5625899596, 'Reg_MSE': 63.7244429336, 'Class_CXE': 690699.838147026}
Epoch: 582 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690750.2450094423, 'Reg_MSE': 63.6544185345, 'Class_CXE': 690686.5905909078}
Epoch: 583 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690736.8792031141, 'Reg_MSE': 63.5846744909, 'Class_CXE': 690673.2945286232}
Epoch: 584 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690723.6833388547, 'Reg_MSE': 63.5152091125, 'Class_CXE': 690660.1681297421}
Epoch: 585 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690710.363470582, 'Reg_MSE': 63.4460207224, 'Class_CXE': 690646.9174498596}
Epoch: 586 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690697.0345588706, 'Reg_MSE': 63.3771076572, 'Class_CXE': 690633.6574512135}
Epoch: 587 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690683.7133601734, 'Reg_MSE': 63.3084682665, 'Class_CXE': 690620.4048919069}
Epoch: 588 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690670.4005711142, 'Reg_MSE': 63.240100913, 'Class_CXE': 690607.1604702012}
Epoch: 589 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690657.0816706935, 'Reg_MSE': 63.1720039724, 'Class_CXE': 690593.9096667211}
Epoch: 590 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690643.7554334141, 'Reg_MSE': 63.1041758329, 'Class_CXE': 690580.6512575812}
Epoch: 591 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690630.4368798094, 'Reg_MSE': 63.0366148956, 'Class_CXE': 690567.4002649138}
Epoch: 592 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690617.1252376421, 'Reg_MSE': 62.9693195739, 'Class_CXE': 690554.1559180681}
Epoch: 593 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690603.8056683514, 'Reg_MSE': 62.9022882936, 'Class_CXE': 690540.9033800578}
Epoch: 594 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690590.4637527758, 'Reg_MSE': 62.8355194924, 'Class_CXE': 690527.6282332834}
Epoch: 595 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690577.1507350841, 'Reg_MSE': 62.7690116206, 'Class_CXE': 690514.3817234635}
Epoch: 596 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690563.8327657629, 'Reg_MSE': 62.7027631399, 'Class_CXE': 690501.1300026231}
Epoch: 597 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690550.5561353454, 'Reg_MSE': 62.6367725241, 'Class_CXE': 690487.9193628213}
Epoch: 598 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690537.232386781, 'Reg_MSE': 62.5710382586, 'Class_CXE': 690474.6613485224}
Epoch: 599 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690523.9071790826, 'Reg_MSE': 62.5055588403, 'Class_CXE': 690461.4016202423}
Epoch: 600 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690510.5981574297, 'Reg_MSE': 62.4403327776, 'Class_CXE': 690448.1578246522}
Epoch: 601 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690497.2788690561, 'Reg_MSE': 62.3753585902, 'Class_CXE': 690434.903510466}
Epoch: 602 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690483.9789593607, 'Reg_MSE': 62.3106348089, 'Class_CXE': 690421.6683245519}
Epoch: 603 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690470.6564493485, 'Reg_MSE': 62.2461599756, 'Class_CXE': 690408.4102893729}
Epoch: 604 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690457.3390614842, 'Reg_MSE': 62.1819326431, 'Class_CXE': 690395.1571288412}
Epoch: 605 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690444.0293624963, 'Reg_MSE': 62.1179513752, 'Class_CXE': 690381.9114111211}
Epoch: 606 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690430.7117194725, 'Reg_MSE': 62.0542147463, 'Class_CXE': 690368.6575047262}
Epoch: 607 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690417.4034404524, 'Reg_MSE': 61.9907213414, 'Class_CXE': 690355.412719111}
Epoch: 608 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690404.0880261867, 'Reg_MSE': 61.9274697561, 'Class_CXE': 690342.1605564306}
Epoch: 609 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690390.7637392171, 'Reg_MSE': 61.8644585962, 'Class_CXE': 690328.8992806209}
Epoch: 610 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690377.4487775717, 'Reg_MSE': 61.8016864781, 'Class_CXE': 690315.6470910936}
Epoch: 611 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690364.1396537429, 'Reg_MSE': 61.739152028, 'Class_CXE': 690302.4005017149}
Epoch: 612 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690350.8176647776, 'Reg_MSE': 61.6768538825, 'Class_CXE': 690289.1408108951}
Epoch: 613 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690337.4823134164, 'Reg_MSE': 61.6147906881, 'Class_CXE': 690275.8675227283}
Epoch: 614 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690324.165945773, 'Reg_MSE': 61.5529611011, 'Class_CXE': 690262.612984672}
Epoch: 615 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690310.8565202754, 'Reg_MSE': 61.4913637876, 'Class_CXE': 690249.3651564878}
Epoch: 616 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690297.5492695898, 'Reg_MSE': 61.4299974234, 'Class_CXE': 690236.1192721664}
Epoch: 617 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690284.2431621817, 'Reg_MSE': 61.3688606941, 'Class_CXE': 690222.8743014876}
Epoch: 618 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690270.9301398966, 'Reg_MSE': 61.3079522944, 'Class_CXE': 690209.6221876022}
Epoch: 619 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690257.615313966, 'Reg_MSE': 61.2472709286, 'Class_CXE': 690196.3680430374}
Epoch: 620 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690244.3018583407, 'Reg_MSE': 61.1868153105, 'Class_CXE': 690183.1150430302}
Epoch: 621 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690230.9829065377, 'Reg_MSE': 61.1265841628, 'Class_CXE': 690169.8563223749}
Epoch: 622 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690217.673640911, 'Reg_MSE': 61.0665762175, 'Class_CXE': 690156.6070646935}
Epoch: 623 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690204.3571639775, 'Reg_MSE': 61.0067902157, 'Class_CXE': 690143.3503737617}
Epoch: 624 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690191.0805204482, 'Reg_MSE': 60.9472249072, 'Class_CXE': 690130.1332955409}
Epoch: 625 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690177.7750359251, 'Reg_MSE': 60.887879051, 'Class_CXE': 690116.887156874}
Epoch: 626 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690164.4693448787, 'Reg_MSE': 60.8287514146, 'Class_CXE': 690103.6405934641}
Epoch: 627 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690151.157206365, 'Reg_MSE': 60.7698407745, 'Class_CXE': 690090.3873655905}
Epoch: 628 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690137.8243398952, 'Reg_MSE': 60.7111459155, 'Class_CXE': 690077.1131939796}
Epoch: 629 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690124.5095397487, 'Reg_MSE': 60.6526656313, 'Class_CXE': 690063.8568741174}
Epoch: 630 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690111.204435728, 'Reg_MSE': 60.5943987237, 'Class_CXE': 690050.6100370043}
Epoch: 631 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690097.90499127, 'Reg_MSE': 60.536344003, 'Class_CXE': 690037.368647267}
Epoch: 632 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690084.5891941197, 'Reg_MSE': 60.478500288, 'Class_CXE': 690024.1106938317}
Epoch: 633 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690071.2781783649, 'Reg_MSE': 60.4208664056, 'Class_CXE': 690010.8573119593}
Epoch: 634 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690057.982063574, 'Reg_MSE': 60.3634411907, 'Class_CXE': 689997.6186223833}
Epoch: 635 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690044.6745018711, 'Reg_MSE': 60.3062234865, 'Class_CXE': 689984.3682783846}
Epoch: 636 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690031.17407528, 'Reg_MSE': 60.2492121441, 'Class_CXE': 689970.9248631359}
Epoch: 637 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690017.8679830761, 'Reg_MSE': 60.1924060226, 'Class_CXE': 689957.6755770536}
Epoch: 638 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690004.5524379902, 'Reg_MSE': 60.1358039888, 'Class_CXE': 689944.4166340014}
Epoch: 639 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689991.2689846159, 'Reg_MSE': 60.0794049174, 'Class_CXE': 689931.1895796985}
Epoch: 640 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689977.9536720368, 'Reg_MSE': 60.023207691, 'Class_CXE': 689917.9304643457}
Epoch: 641 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689964.6483179608, 'Reg_MSE': 59.9672111994, 'Class_CXE': 689904.6811067614}
Epoch: 642 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689951.3419871796, 'Reg_MSE': 59.9114143405, 'Class_CXE': 689891.4305728391}
Epoch: 643 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689938.028290976, 'Reg_MSE': 59.8558160194, 'Class_CXE': 689878.1724749566}
Epoch: 644 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689924.7252580698, 'Reg_MSE': 59.8004151486, 'Class_CXE': 689864.9248429212}
Epoch: 645 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689911.4253122637, 'Reg_MSE': 59.7452106482, 'Class_CXE': 689851.6801016155}
Epoch: 646 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689898.1177514162, 'Reg_MSE': 59.6902014455, 'Class_CXE': 689838.4275499707}
Epoch: 647 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689884.8192491393, 'Reg_MSE': 59.635386475, 'Class_CXE': 689825.1838626643}
Epoch: 648 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689871.5145750389, 'Reg_MSE': 59.5807646786, 'Class_CXE': 689811.9338103603}
Epoch: 649 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689858.2671814726, 'Reg_MSE': 59.526335005, 'Class_CXE': 689798.7408464677}
Epoch: 650 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689844.9689299664, 'Reg_MSE': 59.4720964103, 'Class_CXE': 689785.4968335561}
Epoch: 651 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689831.6418788515, 'Reg_MSE': 59.4180478574, 'Class_CXE': 689772.2238309941}
Epoch: 652 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689818.3369599353, 'Reg_MSE': 59.3641883163, 'Class_CXE': 689758.972771619}
Epoch: 653 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689805.0808787864, 'Reg_MSE': 59.3105167636, 'Class_CXE': 689745.7703620228}
Epoch: 654 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689791.7813415716, 'Reg_MSE': 59.2570321831, 'Class_CXE': 689732.5243093885}
Epoch: 655 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689778.4751633618, 'Reg_MSE': 59.2037335651, 'Class_CXE': 689719.2714297967}
Epoch: 656 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689765.1817936836, 'Reg_MSE': 59.1506199067, 'Class_CXE': 689706.0311737768}
Epoch: 657 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689751.8789980564, 'Reg_MSE': 59.0976902116, 'Class_CXE': 689692.7813078448}
Epoch: 658 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689738.5770322224, 'Reg_MSE': 59.0449434903, 'Class_CXE': 689679.5320887321}
Epoch: 659 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689725.2659119272, 'Reg_MSE': 58.9923787595, 'Class_CXE': 689666.2735331677}
Epoch: 660 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689711.9294644254, 'Reg_MSE': 58.9399950428, 'Class_CXE': 689652.9894693827}
Epoch: 661 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689698.5947841281, 'Reg_MSE': 58.8877913698, 'Class_CXE': 689639.7069927583}
Epoch: 662 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689685.2969013326, 'Reg_MSE': 58.8357667769, 'Class_CXE': 689626.4611345556}
Epoch: 663 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689672.006271604, 'Reg_MSE': 58.7839203064, 'Class_CXE': 689613.2223512976}
Epoch: 664 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689658.7019284775, 'Reg_MSE': 58.7322510073, 'Class_CXE': 689599.9696774702}
Epoch: 665 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689645.3964594549, 'Reg_MSE': 58.6807579346, 'Class_CXE': 689586.7157015203}
Epoch: 666 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689632.0971521657, 'Reg_MSE': 58.6294401494, 'Class_CXE': 689573.4677120163}
Epoch: 667 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689618.7905054921, 'Reg_MSE': 58.5782967191, 'Class_CXE': 689560.212208773}
Epoch: 668 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689605.4935098519, 'Reg_MSE': 58.5273267171, 'Class_CXE': 689546.9661831347}
Epoch: 669 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689592.1882412634, 'Reg_MSE': 58.4765292228, 'Class_CXE': 689533.7117120406}
Epoch: 670 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689953.2510884937, 'Reg_MSE': 58.4259033216, 'Class_CXE': 689894.8251851721}
Epoch: 671 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689956.9951505159, 'Reg_MSE': 58.3754481048, 'Class_CXE': 689898.6197024111}
Epoch: 672 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689943.6864252297, 'Reg_MSE': 58.3251626697, 'Class_CXE': 689885.36126256}
Epoch: 673 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689930.3871126523, 'Reg_MSE': 58.2750461192, 'Class_CXE': 689872.1120665331}
Epoch: 674 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689917.0715643025, 'Reg_MSE': 58.2250975622, 'Class_CXE': 689858.8464667404}
Epoch: 675 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689903.773966149, 'Reg_MSE': 58.1753161132, 'Class_CXE': 689845.5986500358}
Epoch: 676 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689890.4706484609, 'Reg_MSE': 58.1257008926, 'Class_CXE': 689832.3449475683}
Epoch: 677 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689877.1572928103, 'Reg_MSE': 58.0762510261, 'Class_CXE': 689819.0810417841}
Epoch: 678 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689863.8258892128, 'Reg_MSE': 58.0269656454, 'Class_CXE': 689805.7989235674}
Epoch: 679 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689850.5262950211, 'Reg_MSE': 57.9778438874, 'Class_CXE': 689792.5484511337}
Epoch: 680 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689837.2394249521, 'Reg_MSE': 57.9288848948, 'Class_CXE': 689779.3105400573}
Epoch: 681 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689823.9280292435, 'Reg_MSE': 57.8800878156, 'Class_CXE': 689766.0479414279}
Epoch: 682 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689810.6111731287, 'Reg_MSE': 57.8314518033, 'Class_CXE': 689752.7797213254}
Epoch: 683 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689797.3055637534, 'Reg_MSE': 57.7829760167, 'Class_CXE': 689739.5225877366}
Epoch: 684 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689784.0036143878, 'Reg_MSE': 57.7346596201, 'Class_CXE': 689726.2689547677}
Epoch: 685 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689770.7010089523, 'Reg_MSE': 57.6865017831, 'Class_CXE': 689713.0145071693}
Epoch: 686 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690131.8069610355, 'Reg_MSE': 57.6385016803, 'Class_CXE': 690074.1684593551}
Epoch: 687 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690118.4928846175, 'Reg_MSE': 57.5906584918, 'Class_CXE': 690060.9022261257}
Epoch: 688 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690105.1703311238, 'Reg_MSE': 57.5429714027, 'Class_CXE': 690047.627359721}
Epoch: 689 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690091.8765878308, 'Reg_MSE': 57.4954396035, 'Class_CXE': 690034.3811482274}
Epoch: 690 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690078.652626304, 'Reg_MSE': 57.4480622896, 'Class_CXE': 690021.2045640143}
Epoch: 691 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690065.3461836444, 'Reg_MSE': 57.4008386613, 'Class_CXE': 690007.945344983}
Epoch: 692 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690052.0261827918, 'Reg_MSE': 57.3537679244, 'Class_CXE': 689994.6724148674}
Epoch: 693 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690038.7186993113, 'Reg_MSE': 57.3068492892, 'Class_CXE': 689981.411850022}
Epoch: 694 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690025.410282259, 'Reg_MSE': 57.2600819713, 'Class_CXE': 689968.1502002877}
Epoch: 695 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690012.0940416034, 'Reg_MSE': 57.2134651911, 'Class_CXE': 689954.8805764123}
Epoch: 696 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 689998.7898502575, 'Reg_MSE': 57.1669981737, 'Class_CXE': 689941.6228520838}
Epoch: 697 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690359.8423343465, 'Reg_MSE': 57.1206801492, 'Class_CXE': 690302.7216541973}
Epoch: 698 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690346.5309208552, 'Reg_MSE': 57.0745103526, 'Class_CXE': 690289.4564105027}
Epoch: 699 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690333.2162828649, 'Reg_MSE': 57.0284880236, 'Class_CXE': 690276.1877948414}
Epoch: 700 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690319.897399514, 'Reg_MSE': 56.9826124064, 'Class_CXE': 690262.9147871076}
Epoch: 701 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690680.9442776167, 'Reg_MSE': 56.9368827503, 'Class_CXE': 690624.0073948664}
Epoch: 702 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690667.6157239765, 'Reg_MSE': 56.8912983088, 'Class_CXE': 690610.7244256677}
Epoch: 703 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690654.278715141, 'Reg_MSE': 56.8458583405, 'Class_CXE': 690597.4328568005}
Epoch: 704 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690640.9292094052, 'Reg_MSE': 56.8005621083, 'Class_CXE': 690584.1286472969}
Epoch: 705 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690627.6060273458, 'Reg_MSE': 56.7554088796, 'Class_CXE': 690570.8506184662}
Epoch: 706 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690614.2688977822, 'Reg_MSE': 56.7103979267, 'Class_CXE': 690557.5584998555}
Epoch: 707 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690600.9480865148, 'Reg_MSE': 56.6655285259, 'Class_CXE': 690544.2825579889}
Epoch: 708 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690587.6434072114, 'Reg_MSE': 56.6207999583, 'Class_CXE': 690531.022607253}
Epoch: 709 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 690574.310597282, 'Reg_MSE': 56.5762115092, 'Class_CXE': 690517.7343857728}
Epoch: 710 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690578.0254369495, 'Reg_MSE': 56.5317624687, 'Class_CXE': 690521.4936744807}
Epoch: 711 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690564.705149054, 'Reg_MSE': 56.4874521307, 'Class_CXE': 690508.2176969233}
Epoch: 712 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690551.383350317, 'Reg_MSE': 56.4432797938, 'Class_CXE': 690494.9400705232}
Epoch: 713 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690912.4274518031, 'Reg_MSE': 56.3992447609, 'Class_CXE': 690856.0282070423}
Epoch: 714 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690899.0978160123, 'Reg_MSE': 56.3553463389, 'Class_CXE': 690842.7424696734}
Epoch: 715 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690885.7607687009, 'Reg_MSE': 56.3115838393, 'Class_CXE': 690829.4491848615}
Epoch: 716 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690872.4381206643, 'Reg_MSE': 56.2679565775, 'Class_CXE': 690816.1701640869}
Epoch: 717 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690859.1214354353, 'Reg_MSE': 56.2244638732, 'Class_CXE': 690802.8969715621}
Epoch: 718 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690845.7916476722, 'Reg_MSE': 56.1811050502, 'Class_CXE': 690789.610542622}
Epoch: 719 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690832.4546080356, 'Reg_MSE': 56.1378794366, 'Class_CXE': 690776.316728599}
Epoch: 720 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690819.1286199219, 'Reg_MSE': 56.0947863643, 'Class_CXE': 690763.0338335575}
Epoch: 721 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690805.804336131, 'Reg_MSE': 56.0518251695, 'Class_CXE': 690749.7525109615}
Epoch: 722 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690792.5607458628, 'Reg_MSE': 56.0089951922, 'Class_CXE': 690736.5517506705}
Epoch: 723 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691153.6766188988, 'Reg_MSE': 55.9662957766, 'Class_CXE': 691097.7103231221}
Epoch: 724 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691140.3372257701, 'Reg_MSE': 55.9237262709, 'Class_CXE': 691084.4134994992}
Epoch: 725 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691126.9941268538, 'Reg_MSE': 55.8812860269, 'Class_CXE': 691071.1128408269}
Epoch: 726 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691113.6593096537, 'Reg_MSE': 55.8389744008, 'Class_CXE': 691057.8203352529}
Epoch: 727 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691100.3404755624, 'Reg_MSE': 55.7967907524, 'Class_CXE': 691044.5436848099}
Epoch: 728 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691086.9914113746, 'Reg_MSE': 55.7547344452, 'Class_CXE': 691031.2366769294}
Epoch: 729 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691073.6524864389, 'Reg_MSE': 55.712804847, 'Class_CXE': 691017.9396815918}
Epoch: 730 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691060.1898972131, 'Reg_MSE': 55.671001329, 'Class_CXE': 691004.5188958842}
Epoch: 731 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691421.2170886152, 'Reg_MSE': 55.6293232664, 'Class_CXE': 691365.5877653487}
Epoch: 732 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691407.8809613534, 'Reg_MSE': 55.5877700381, 'Class_CXE': 691352.2931913154}
Epoch: 733 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691394.4933858793, 'Reg_MSE': 55.5463410265, 'Class_CXE': 691338.9470448528}
Epoch: 734 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691381.1349470841, 'Reg_MSE': 55.5050356182, 'Class_CXE': 691325.6299114659}
Epoch: 735 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691367.7996655799, 'Reg_MSE': 55.463853203, 'Class_CXE': 691312.3358123769}
Epoch: 736 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691354.4672616791, 'Reg_MSE': 55.4227931745, 'Class_CXE': 691299.0444685046}
Epoch: 737 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691341.1242710131, 'Reg_MSE': 55.3818549301, 'Class_CXE': 691285.742416083}
Epoch: 738 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691327.8020251027, 'Reg_MSE': 55.3410378706, 'Class_CXE': 691272.460987232}
Epoch: 739 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691314.4584967562, 'Reg_MSE': 55.3003414004, 'Class_CXE': 691259.1581553557}
Epoch: 740 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691301.1207739778, 'Reg_MSE': 55.2597649275, 'Class_CXE': 691245.8610090503}
Epoch: 741 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691287.775416605, 'Reg_MSE': 55.2193078635, 'Class_CXE': 691232.5561087416}
Epoch: 742 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691274.4279060623, 'Reg_MSE': 55.1789696232, 'Class_CXE': 691219.2489364392}
Epoch: 743 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691261.0976412105, 'Reg_MSE': 55.1387496253, 'Class_CXE': 691205.9588915852}
Epoch: 744 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691247.7557047608, 'Reg_MSE': 55.0986472916, 'Class_CXE': 691192.6570574691}
Epoch: 745 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691234.3907930951, 'Reg_MSE': 55.0586620475, 'Class_CXE': 691179.3321310476}
Epoch: 746 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691221.0554267158, 'Reg_MSE': 55.0187933218, 'Class_CXE': 691166.036633394}
Epoch: 747 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691207.7202562235, 'Reg_MSE': 54.9790405467, 'Class_CXE': 691152.7412156768}
Epoch: 748 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691194.3922683564, 'Reg_MSE': 54.9394031576, 'Class_CXE': 691139.4528651988}
Epoch: 749 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691181.0660908287, 'Reg_MSE': 54.8998805934, 'Class_CXE': 691126.1662102353}
Epoch: 750 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691167.7121897938, 'Reg_MSE': 54.8604722964, 'Class_CXE': 691112.8517174975}
Epoch: 751 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691154.3753905423, 'Reg_MSE': 54.8211777119, 'Class_CXE': 691099.5542128304}
Epoch: 752 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691141.0195720609, 'Reg_MSE': 54.7819962887, 'Class_CXE': 691086.2375757722}
Epoch: 753 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691127.6891336526, 'Reg_MSE': 54.7429274787, 'Class_CXE': 691072.9462061739}
Epoch: 754 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691114.3486482381, 'Reg_MSE': 54.7039707372, 'Class_CXE': 691059.6446775009}
Epoch: 755 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691101.0130604756, 'Reg_MSE': 54.6651255226, 'Class_CXE': 691046.3479349529}
Epoch: 756 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691087.6850588068, 'Reg_MSE': 54.6263912965, 'Class_CXE': 691033.0586675103}
Epoch: 757 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691074.3420589585, 'Reg_MSE': 54.5877675236, 'Class_CXE': 691019.7542914349}
Epoch: 758 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691061.0036130156, 'Reg_MSE': 54.5492536718, 'Class_CXE': 691006.4543593439}
Epoch: 759 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691422.0440924659, 'Reg_MSE': 54.5108492121, 'Class_CXE': 691367.5332432538}
Epoch: 760 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691408.8893044711, 'Reg_MSE': 54.4725536187, 'Class_CXE': 691354.4167508524}
Epoch: 761 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691395.5411497583, 'Reg_MSE': 54.4343663686, 'Class_CXE': 691341.1067833897}
Epoch: 762 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691399.2918526762, 'Reg_MSE': 54.3962869423, 'Class_CXE': 691344.8955657339}
Epoch: 763 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691385.9531627257, 'Reg_MSE': 54.3583148228, 'Class_CXE': 691331.5948479029}
Epoch: 764 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691372.6089302315, 'Reg_MSE': 54.3204494966, 'Class_CXE': 691318.2884807349}
Epoch: 765 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691359.2694945551, 'Reg_MSE': 54.2826904529, 'Class_CXE': 691304.9868041022}
Epoch: 766 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691719.9565212526, 'Reg_MSE': 54.245037184, 'Class_CXE': 691665.7114840687}
Epoch: 767 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691706.6009339449, 'Reg_MSE': 54.207489185, 'Class_CXE': 691652.3934447599}
Epoch: 768 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691693.257099569, 'Reg_MSE': 54.1700459543, 'Class_CXE': 691639.0870536147}
Epoch: 769 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691679.9027249179, 'Reg_MSE': 54.1327069927, 'Class_CXE': 691625.7700179252}
Epoch: 770 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691666.5653540668, 'Reg_MSE': 54.0954718044, 'Class_CXE': 691612.4698822624}
Epoch: 771 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691653.2131210612, 'Reg_MSE': 54.0583398961, 'Class_CXE': 691599.1547811651}
Epoch: 772 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691639.867069947, 'Reg_MSE': 54.0213107775, 'Class_CXE': 691585.8457591694}
Epoch: 773 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691626.5114512901, 'Reg_MSE': 53.9843839613, 'Class_CXE': 691572.5270673288}
Epoch: 774 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691613.1399633733, 'Reg_MSE': 53.9475589627, 'Class_CXE': 691559.1924044106}
Epoch: 775 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691599.7851339615, 'Reg_MSE': 53.9108353, 'Class_CXE': 691545.8742986615}
Epoch: 776 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691586.4275785986, 'Reg_MSE': 53.8742124941, 'Class_CXE': 691532.5533661045}
Epoch: 777 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691573.0854502124, 'Reg_MSE': 53.8376900687, 'Class_CXE': 691519.2477601436}
Epoch: 778 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691559.7367345913, 'Reg_MSE': 53.8012675503, 'Class_CXE': 691505.935467041}
Epoch: 779 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691546.3795744695, 'Reg_MSE': 53.7649444682, 'Class_CXE': 691492.6146300014}
Epoch: 780 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691907.3915496002, 'Reg_MSE': 53.7287203542, 'Class_CXE': 691853.662829246}
Epoch: 781 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691894.0314061191, 'Reg_MSE': 53.6925947429, 'Class_CXE': 691840.3388113762}
Epoch: 782 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691880.6869898727, 'Reg_MSE': 53.6565671717, 'Class_CXE': 691827.030422701}
Epoch: 783 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691867.2350979155, 'Reg_MSE': 53.6206371805, 'Class_CXE': 691813.614460735}
Epoch: 784 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691853.8371638827, 'Reg_MSE': 53.5848043119, 'Class_CXE': 691800.2523595708}
Epoch: 785 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691840.4856126939, 'Reg_MSE': 53.549068111, 'Class_CXE': 691786.936544583}
Epoch: 786 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691827.1297747236, 'Reg_MSE': 53.5134281258, 'Class_CXE': 691773.6163465978}
Epoch: 787 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691813.7680513759, 'Reg_MSE': 53.4778839067, 'Class_CXE': 691760.2901674692}
Epoch: 788 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691800.4125664937, 'Reg_MSE': 53.4424350066, 'Class_CXE': 691746.9701314871}
Epoch: 789 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691787.0537082608, 'Reg_MSE': 53.4070809811, 'Class_CXE': 691733.6466272797}
Epoch: 790 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691773.6912658813, 'Reg_MSE': 53.3718213883, 'Class_CXE': 691720.319444493}
Epoch: 791 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691760.3281894451, 'Reg_MSE': 53.3366557888, 'Class_CXE': 691706.9915336563}
Epoch: 792 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691746.9823259743, 'Reg_MSE': 53.3015837458, 'Class_CXE': 691693.6807422285}
Epoch: 793 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691733.6207422527, 'Reg_MSE': 53.2666048248, 'Class_CXE': 691680.3541374279}
Epoch: 794 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691720.2815895255, 'Reg_MSE': 53.2317185939, 'Class_CXE': 691667.0498709317}
Epoch: 795 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691706.9288275266, 'Reg_MSE': 53.1969246238, 'Class_CXE': 691653.7319029028}
Epoch: 796 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691693.5551212934, 'Reg_MSE': 53.1622224874, 'Class_CXE': 691640.392898806}
Epoch: 797 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691680.1968910358, 'Reg_MSE': 53.1276117601, 'Class_CXE': 691627.0692792757}
Epoch: 798 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691666.838937945, 'Reg_MSE': 53.0930920199, 'Class_CXE': 691613.7458459251}
Epoch: 799 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691653.4697912702, 'Reg_MSE': 53.0586628468, 'Class_CXE': 691600.4111284234}
Epoch: 800 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691640.0862628951, 'Reg_MSE': 53.0243238237, 'Class_CXE': 691587.0619390714}
Epoch: 801 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691626.6205896527, 'Reg_MSE': 52.9900745353, 'Class_CXE': 691573.6305151174}
Epoch: 802 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691613.2655098704, 'Reg_MSE': 52.9559145692, 'Class_CXE': 691560.3095953012}
Epoch: 803 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691599.9139108465, 'Reg_MSE': 52.921843515, 'Class_CXE': 691546.9920673314}
Epoch: 804 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691586.5581250449, 'Reg_MSE': 52.8878609646, 'Class_CXE': 691533.6702640803}
Epoch: 805 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691573.1957997322, 'Reg_MSE': 52.8539665125, 'Class_CXE': 691520.3418332197}
Epoch: 806 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691559.8480999471, 'Reg_MSE': 52.8201597551, 'Class_CXE': 691507.027940192}
Epoch: 807 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691546.5058320204, 'Reg_MSE': 52.7864402915, 'Class_CXE': 691493.719391729}
Epoch: 808 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691533.1480177966, 'Reg_MSE': 52.7528077227, 'Class_CXE': 691480.3952100739}
Epoch: 809 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691519.9593110108, 'Reg_MSE': 52.7192616521, 'Class_CXE': 691467.2400493587}
Epoch: 810 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691506.5988072566, 'Reg_MSE': 52.6858016854, 'Class_CXE': 691453.9130055711}
Epoch: 811 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691493.242651769, 'Reg_MSE': 52.6524274304, 'Class_CXE': 691440.5902243386}
Epoch: 812 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691479.89191024, 'Reg_MSE': 52.6191384972, 'Class_CXE': 691427.2727717429}
Epoch: 813 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691466.5294425272, 'Reg_MSE': 52.585934498, 'Class_CXE': 691413.9435080292}
Epoch: 814 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691453.173262893, 'Reg_MSE': 52.5528150473, 'Class_CXE': 691400.6204478457}
Epoch: 815 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691439.8279391084, 'Reg_MSE': 52.5197797617, 'Class_CXE': 691387.3081593467}
Epoch: 816 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691426.4691831828, 'Reg_MSE': 52.4868282599, 'Class_CXE': 691373.9823549229}
Epoch: 817 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691413.2031263605, 'Reg_MSE': 52.4539601628, 'Class_CXE': 691360.7491661977}
Epoch: 818 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691399.8460591316, 'Reg_MSE': 52.4211750934, 'Class_CXE': 691347.4248840382}
Epoch: 819 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691386.4941602641, 'Reg_MSE': 52.3884726768, 'Class_CXE': 691334.1056875874}
Epoch: 820 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691373.1281066028, 'Reg_MSE': 52.3558525403, 'Class_CXE': 691320.7722540625}
Epoch: 821 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691359.7763370294, 'Reg_MSE': 52.3233143132, 'Class_CXE': 691307.4530227162}
Epoch: 822 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691346.4222788274, 'Reg_MSE': 52.2908576268, 'Class_CXE': 691294.1314212006}
Epoch: 823 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691333.0683032353, 'Reg_MSE': 52.2584821146, 'Class_CXE': 691280.8098211207}
Epoch: 824 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691319.7191946525, 'Reg_MSE': 52.2261874121, 'Class_CXE': 691267.4930072405}
Epoch: 825 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691306.3733503479, 'Reg_MSE': 52.1939731568, 'Class_CXE': 691254.179377191}
Epoch: 826 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691293.0144762694, 'Reg_MSE': 52.1618389882, 'Class_CXE': 691240.8526372812}
Epoch: 827 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691279.657560525, 'Reg_MSE': 52.1297845478, 'Class_CXE': 691227.5277759772}
Epoch: 828 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691266.302993921, 'Reg_MSE': 52.0978094793, 'Class_CXE': 691214.2051844416}
Epoch: 829 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691252.9467431345, 'Reg_MSE': 52.065913428, 'Class_CXE': 691200.8808297066}
Epoch: 830 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691239.5920613164, 'Reg_MSE': 52.0340960416, 'Class_CXE': 691187.5579652748}
Epoch: 831 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691226.2472145834, 'Reg_MSE': 52.0023569695, 'Class_CXE': 691174.2448576139}
Epoch: 832 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691212.8799074201, 'Reg_MSE': 51.970695863, 'Class_CXE': 691160.9092115571}
Epoch: 833 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691199.5484935299, 'Reg_MSE': 51.9391123755, 'Class_CXE': 691147.6093811544}
Epoch: 834 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691186.1919013906, 'Reg_MSE': 51.9076061622, 'Class_CXE': 691134.2842952283}
Epoch: 835 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691172.836629248, 'Reg_MSE': 51.8761768803, 'Class_CXE': 691120.9604523677}
Epoch: 836 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691159.481015036, 'Reg_MSE': 51.8448241889, 'Class_CXE': 691107.6361908471}
Epoch: 837 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691146.1253102506, 'Reg_MSE': 51.8135477489, 'Class_CXE': 691094.3117625016}
Epoch: 838 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691132.7588065286, 'Reg_MSE': 51.7823472231, 'Class_CXE': 691080.9764593055}
Epoch: 839 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691119.412709596, 'Reg_MSE': 51.7512222762, 'Class_CXE': 691067.6614873197}
Epoch: 840 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691106.0652899164, 'Reg_MSE': 51.7201725748, 'Class_CXE': 691054.3451173416}
Epoch: 841 - {'Class_acc': 0.5253635734072022, 'CXE+MSE': 691092.5774195637, 'Reg_MSE': 51.6891977871, 'Class_CXE': 691040.8882217766}
Epoch: 842 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691096.2586755245, 'Reg_MSE': 51.6582975835, 'Class_CXE': 691044.600377941}
Epoch: 843 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691082.9068597086, 'Reg_MSE': 51.6274716358, 'Class_CXE': 691031.2793880728}
Epoch: 844 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691069.548417896, 'Reg_MSE': 51.5967196181, 'Class_CXE': 691017.951698278}
Epoch: 845 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691056.1949330688, 'Reg_MSE': 51.5660412057, 'Class_CXE': 691004.6288918632}
Epoch: 846 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691042.84273277, 'Reg_MSE': 51.5354360763, 'Class_CXE': 690991.3072966937}
Epoch: 847 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691029.5189490813, 'Reg_MSE': 51.5049039089, 'Class_CXE': 690978.0140451724}
Epoch: 848 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691016.1742339276, 'Reg_MSE': 51.4744443844, 'Class_CXE': 690964.6997895432}
Epoch: 849 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691002.8227429545, 'Reg_MSE': 51.4440571856, 'Class_CXE': 690951.3786857689}
Epoch: 850 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690989.4667104629, 'Reg_MSE': 51.4137419969, 'Class_CXE': 690938.0529684661}
Epoch: 851 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690976.1232878168, 'Reg_MSE': 51.3834985044, 'Class_CXE': 690924.7397893124}
Epoch: 852 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690962.7587591251, 'Reg_MSE': 51.3533263961, 'Class_CXE': 690911.405432729}
Epoch: 853 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690949.4093718193, 'Reg_MSE': 51.3232253614, 'Class_CXE': 690898.0861464579}
Epoch: 854 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690936.0579029027, 'Reg_MSE': 51.2931950918, 'Class_CXE': 690884.764707811}
Epoch: 855 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690922.6856729072, 'Reg_MSE': 51.26323528, 'Class_CXE': 690871.4224376271}
Epoch: 856 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690909.3412805173, 'Reg_MSE': 51.2333456209, 'Class_CXE': 690858.1079348964}
Epoch: 857 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690895.8905080303, 'Reg_MSE': 51.2035258106, 'Class_CXE': 690844.6869822197}
Epoch: 858 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690882.5422554592, 'Reg_MSE': 51.1737755473, 'Class_CXE': 690831.3684799119}
Epoch: 859 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690869.1858053446, 'Reg_MSE': 51.1440945304, 'Class_CXE': 690818.0417108142}
Epoch: 860 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690855.838094434, 'Reg_MSE': 51.1144824612, 'Class_CXE': 690804.7236119728}
Epoch: 861 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690842.508985391, 'Reg_MSE': 51.0849390427, 'Class_CXE': 690791.4240463483}
Epoch: 862 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690829.1312362887, 'Reg_MSE': 51.0554639792, 'Class_CXE': 690778.0757723094}
Epoch: 863 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690815.8626527606, 'Reg_MSE': 51.026056977, 'Class_CXE': 690764.8365957836}
Epoch: 864 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691176.8893808338, 'Reg_MSE': 50.9967177437, 'Class_CXE': 691125.8926630901}
Epoch: 865 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691163.5310272089, 'Reg_MSE': 50.9674459885, 'Class_CXE': 691112.5635812203}
Epoch: 866 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691150.168858198, 'Reg_MSE': 50.9382414224, 'Class_CXE': 691099.2306167756}
Epoch: 867 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691136.8178227424, 'Reg_MSE': 50.9091037578, 'Class_CXE': 691085.9087189846}
Epoch: 868 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691123.4644657086, 'Reg_MSE': 50.8800327086, 'Class_CXE': 691072.584433}
Epoch: 869 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691110.107110397, 'Reg_MSE': 50.8510279904, 'Class_CXE': 691059.2560824066}
Epoch: 870 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691096.7505156319, 'Reg_MSE': 50.8220893203, 'Class_CXE': 691045.9284263116}
Epoch: 871 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691083.3957782593, 'Reg_MSE': 50.7932164169, 'Class_CXE': 691032.6025618424}
Epoch: 872 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691070.0338082077, 'Reg_MSE': 50.7644090003, 'Class_CXE': 691019.2693992074}
Epoch: 873 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691056.6827259519, 'Reg_MSE': 50.7356667921, 'Class_CXE': 691005.9470591597}
Epoch: 874 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691043.3296156361, 'Reg_MSE': 50.7069895155, 'Class_CXE': 690992.6226261207}
Epoch: 875 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691029.9795337929, 'Reg_MSE': 50.6783768951, 'Class_CXE': 690979.3011568978}
Epoch: 876 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691016.6185914242, 'Reg_MSE': 50.6498286571, 'Class_CXE': 690965.9687627672}
Epoch: 877 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691003.2755001888, 'Reg_MSE': 50.6213445289, 'Class_CXE': 690952.65415566}
Epoch: 878 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690989.9059493196, 'Reg_MSE': 50.5929242398, 'Class_CXE': 690939.3130250798}
Epoch: 879 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690976.5483768459, 'Reg_MSE': 50.5645675202, 'Class_CXE': 690925.9838093257}
Epoch: 880 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690963.1892071137, 'Reg_MSE': 50.536274102, 'Class_CXE': 690912.6529330116}
Epoch: 881 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690949.8353175529, 'Reg_MSE': 50.5080437187, 'Class_CXE': 690899.3272738341}
Epoch: 882 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 690936.4875894991, 'Reg_MSE': 50.4798761052, 'Class_CXE': 690886.0077133939}
Epoch: 883 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691297.5013553008, 'Reg_MSE': 50.4517709976, 'Class_CXE': 691247.0495843032}
Epoch: 884 - {'Class_acc': 0.5254501385041551, 'CXE+MSE': 691284.1431157023, 'Reg_MSE': 50.4237281336, 'Class_CXE': 691233.7193875688}
Epoch: 885 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 691270.7808322356, 'Reg_MSE': 50.3957472524, 'Class_CXE': 691220.3850849832}
Epoch: 886 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 691257.4156445342, 'Reg_MSE': 50.3678280943, 'Class_CXE': 691207.0478164399}
Epoch: 887 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 691244.0491399282, 'Reg_MSE': 50.3399704014, 'Class_CXE': 691193.7091695268}
Epoch: 888 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 691230.6672581477, 'Reg_MSE': 50.3121739167, 'Class_CXE': 691180.355084231}
Epoch: 889 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 691217.3046283519, 'Reg_MSE': 50.284438385, 'Class_CXE': 691167.0201899669}
Epoch: 890 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 691578.3085342171, 'Reg_MSE': 50.2567635522, 'Class_CXE': 691528.0517706649}
Epoch: 891 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 691564.9519645165, 'Reg_MSE': 50.2291491657, 'Class_CXE': 691514.7228153509}
Epoch: 892 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 691551.5861276523, 'Reg_MSE': 50.2015949741, 'Class_CXE': 691501.3845326782}
Epoch: 893 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 691538.2385467363, 'Reg_MSE': 50.1741007275, 'Class_CXE': 691488.0644460088}
Epoch: 894 - {'Class_acc': 0.525536703601108, 'CXE+MSE': 691524.8709454404, 'Reg_MSE': 50.1466661772, 'Class_CXE': 691474.7242792632}
Epoch: 895 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691528.5242923462, 'Reg_MSE': 50.1192910759, 'Class_CXE': 691478.4050012704}
Epoch: 896 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691889.5261185155, 'Reg_MSE': 50.0919751776, 'Class_CXE': 691839.4341433379}
Epoch: 897 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691876.1537465798, 'Reg_MSE': 50.0647182376, 'Class_CXE': 691826.0890283422}
Epoch: 898 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691862.7816024388, 'Reg_MSE': 50.0375200125, 'Class_CXE': 691812.7440824263}
Epoch: 899 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691849.404886742, 'Reg_MSE': 50.0103802602, 'Class_CXE': 691799.3945064818}
Epoch: 900 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691836.046318543, 'Reg_MSE': 49.9832987399, 'Class_CXE': 691786.0630198031}
Epoch: 901 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691822.6689662877, 'Reg_MSE': 49.956275212, 'Class_CXE': 691772.7126910758}
Epoch: 902 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691809.3003998039, 'Reg_MSE': 49.9293094384, 'Class_CXE': 691759.3710903655}
Epoch: 903 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691795.9290004239, 'Reg_MSE': 49.9024011818, 'Class_CXE': 691746.0265992421}
Epoch: 904 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691782.4843040133, 'Reg_MSE': 49.8755502067, 'Class_CXE': 691732.6087538066}
Epoch: 905 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 692143.4831695464, 'Reg_MSE': 49.8487562785, 'Class_CXE': 692093.6344132678}
Epoch: 906 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 692130.09803289, 'Reg_MSE': 49.8220191639, 'Class_CXE': 692080.2760137261}
Epoch: 907 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 692116.7099904, 'Reg_MSE': 49.7953386309, 'Class_CXE': 692066.9146517691}
Epoch: 908 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 692103.3316315196, 'Reg_MSE': 49.7687144487, 'Class_CXE': 692053.5629170709}
Epoch: 909 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 692089.9873743629, 'Reg_MSE': 49.7421463876, 'Class_CXE': 692040.2452279753}
Epoch: 910 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 692076.6093802112, 'Reg_MSE': 49.7156342194, 'Class_CXE': 692026.8937459919}
Epoch: 911 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 692063.2283771392, 'Reg_MSE': 49.6891777167, 'Class_CXE': 692013.5391994225}
Epoch: 912 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 692049.8555596157, 'Reg_MSE': 49.6627766536, 'Class_CXE': 692000.192782962}
Epoch: 913 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 692036.564996954, 'Reg_MSE': 49.6364308054, 'Class_CXE': 691986.9285661486}
Epoch: 914 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 692023.1800797154, 'Reg_MSE': 49.6101399484, 'Class_CXE': 691973.569939767}
Epoch: 915 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 692009.7993465773, 'Reg_MSE': 49.5839038602, 'Class_CXE': 691960.2154427171}
Epoch: 916 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691996.425669702, 'Reg_MSE': 49.5577223194, 'Class_CXE': 691946.8679473826}
Epoch: 917 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691983.041498227, 'Reg_MSE': 49.531595106, 'Class_CXE': 691933.509903121}
Epoch: 918 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691969.6693469781, 'Reg_MSE': 49.5055220011, 'Class_CXE': 691920.163824977}
Epoch: 919 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691956.2479201466, 'Reg_MSE': 49.4795027867, 'Class_CXE': 691906.7684173599}
Epoch: 920 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691942.8674576606, 'Reg_MSE': 49.4535372463, 'Class_CXE': 691893.4139204143}
Epoch: 921 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691929.485493588, 'Reg_MSE': 49.4276251642, 'Class_CXE': 691880.0578684239}
Epoch: 922 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691916.1050623882, 'Reg_MSE': 49.4017663262, 'Class_CXE': 691866.7032960621}
Epoch: 923 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691902.723699687, 'Reg_MSE': 49.3759605189, 'Class_CXE': 691853.3477391681}
Epoch: 924 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691889.3334096444, 'Reg_MSE': 49.3502075301, 'Class_CXE': 691839.9832021144}
Epoch: 925 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691875.9572091756, 'Reg_MSE': 49.3245071487, 'Class_CXE': 691826.6327020269}
Epoch: 926 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691862.5807553892, 'Reg_MSE': 49.2988591648, 'Class_CXE': 691813.2818962244}
Epoch: 927 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691849.2218158732, 'Reg_MSE': 49.2732633696, 'Class_CXE': 691799.9485525036}
Epoch: 928 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691835.9578055562, 'Reg_MSE': 49.2477195551, 'Class_CXE': 691786.7100860012}
Epoch: 929 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691822.5822610002, 'Reg_MSE': 49.2222275148, 'Class_CXE': 691773.3600334853}
Epoch: 930 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691809.2097096229, 'Reg_MSE': 49.196787043, 'Class_CXE': 691760.0129225799}
Epoch: 931 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691795.8203999366, 'Reg_MSE': 49.1713979352, 'Class_CXE': 691746.6490020014}
Epoch: 932 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691782.4367145183, 'Reg_MSE': 49.1460599877, 'Class_CXE': 691733.2906545306}
Epoch: 933 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691769.0583864015, 'Reg_MSE': 49.1207729983, 'Class_CXE': 691719.9376134032}
Epoch: 934 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691755.6758146773, 'Reg_MSE': 49.0955367655, 'Class_CXE': 691706.5802779117}
Epoch: 935 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691742.3048473728, 'Reg_MSE': 49.0703510889, 'Class_CXE': 691693.2344962839}
Epoch: 936 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691728.9273388336, 'Reg_MSE': 49.0452157693, 'Class_CXE': 691679.8821230642}
Epoch: 937 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691715.5533340712, 'Reg_MSE': 49.0201306084, 'Class_CXE': 691666.5332034628}
Epoch: 938 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691702.188109301, 'Reg_MSE': 48.9950954089, 'Class_CXE': 691653.1930138922}
Epoch: 939 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691688.8098173919, 'Reg_MSE': 48.9701099746, 'Class_CXE': 691639.8397074173}
Epoch: 940 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691675.4201799817, 'Reg_MSE': 48.9451741102, 'Class_CXE': 691626.4750058715}
Epoch: 941 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691662.0469146753, 'Reg_MSE': 48.9202876217, 'Class_CXE': 691613.1266270536}
Epoch: 942 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691648.670999622, 'Reg_MSE': 48.8954503157, 'Class_CXE': 691599.7755493063}
Epoch: 943 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691635.2967395678, 'Reg_MSE': 48.870662, 'Class_CXE': 691586.4260775677}
Epoch: 944 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691621.8545129572, 'Reg_MSE': 48.8459224835, 'Class_CXE': 691573.0085904737}
Epoch: 945 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691608.5091931273, 'Reg_MSE': 48.8212315759, 'Class_CXE': 691559.6879615514}
Epoch: 946 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691595.1284509631, 'Reg_MSE': 48.7965890879, 'Class_CXE': 691546.3318618751}
Epoch: 947 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691581.7535819795, 'Reg_MSE': 48.7719948313, 'Class_CXE': 691532.9815871483}
Epoch: 948 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691568.3738442064, 'Reg_MSE': 48.7474486188, 'Class_CXE': 691519.6263955877}
Epoch: 949 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691554.9744253252, 'Reg_MSE': 48.7229502639, 'Class_CXE': 691506.2514750613}
Epoch: 950 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691541.6016551114, 'Reg_MSE': 48.6984995813, 'Class_CXE': 691492.9031555301}
Epoch: 951 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691902.5888465594, 'Reg_MSE': 48.6740963865, 'Class_CXE': 691853.9147501729}
Epoch: 952 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691889.2023929838, 'Reg_MSE': 48.6497404961, 'Class_CXE': 691840.5526524878}
Epoch: 953 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691875.8134753166, 'Reg_MSE': 48.6254317275, 'Class_CXE': 691827.1880435891}
Epoch: 954 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691862.4291522212, 'Reg_MSE': 48.601169899, 'Class_CXE': 691813.8279823222}
Epoch: 955 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691849.0460078304, 'Reg_MSE': 48.57695483, 'Class_CXE': 691800.4690530004}
Epoch: 956 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691835.6560131607, 'Reg_MSE': 48.5527863407, 'Class_CXE': 691787.10322682}
Epoch: 957 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691822.2725140941, 'Reg_MSE': 48.5286642522, 'Class_CXE': 691773.7438498419}
Epoch: 958 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691808.8917229808, 'Reg_MSE': 48.5045883866, 'Class_CXE': 691760.3871345941}
Epoch: 959 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691795.5086726886, 'Reg_MSE': 48.4805585668, 'Class_CXE': 691747.0281141219}
Epoch: 960 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691782.1252927368, 'Reg_MSE': 48.4565746167, 'Class_CXE': 691733.6687181201}
Epoch: 961 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691768.7419991038, 'Reg_MSE': 48.4326363612, 'Class_CXE': 691720.3093627426}
Epoch: 962 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691755.3629076775, 'Reg_MSE': 48.4087436257, 'Class_CXE': 691706.9541640518}
Epoch: 963 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691741.9776736184, 'Reg_MSE': 48.384896237, 'Class_CXE': 691693.5927773813}
Epoch: 964 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691728.5787134734, 'Reg_MSE': 48.3610940223, 'Class_CXE': 691680.2176194511}
Epoch: 965 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691715.2233691119, 'Reg_MSE': 48.3373368101, 'Class_CXE': 691666.8860323018}
Epoch: 966 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691701.8351717421, 'Reg_MSE': 48.3136244294, 'Class_CXE': 691653.5215473127}
Epoch: 967 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691705.4719995718, 'Reg_MSE': 48.2899567104, 'Class_CXE': 691657.1820428614}
Epoch: 968 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691692.0911503009, 'Reg_MSE': 48.2663334839, 'Class_CXE': 691643.8248168171}
Epoch: 969 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691678.7244373319, 'Reg_MSE': 48.2427545816, 'Class_CXE': 691630.4816827503}
Epoch: 970 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691665.3414948776, 'Reg_MSE': 48.2192198363, 'Class_CXE': 691617.1222750413}
Epoch: 971 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691651.955564021, 'Reg_MSE': 48.1957290813, 'Class_CXE': 691603.7598349397}
Epoch: 972 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691638.5678615207, 'Reg_MSE': 48.1722821509, 'Class_CXE': 691590.3955793697}
Epoch: 973 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691625.1874599184, 'Reg_MSE': 48.1488788803, 'Class_CXE': 691577.0385810381}
Epoch: 974 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691611.7727591763, 'Reg_MSE': 48.1255191055, 'Class_CXE': 691563.6472400709}
Epoch: 975 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691598.389102102, 'Reg_MSE': 48.1022026632, 'Class_CXE': 691550.2868994388}
Epoch: 976 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691585.0104660719, 'Reg_MSE': 48.0789293911, 'Class_CXE': 691536.9315366808}
Epoch: 977 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691571.7606993889, 'Reg_MSE': 48.0556991276, 'Class_CXE': 691523.7050002613}
Epoch: 978 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691558.3553840568, 'Reg_MSE': 48.0325117119, 'Class_CXE': 691510.3228723449}
Epoch: 979 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691544.9751051689, 'Reg_MSE': 48.0093669842, 'Class_CXE': 691496.9657381846}
Epoch: 980 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691531.5458374609, 'Reg_MSE': 47.9862647852, 'Class_CXE': 691483.5595726757}
Epoch: 981 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691518.1507561733, 'Reg_MSE': 47.9632049566, 'Class_CXE': 691470.1875512167}
Epoch: 982 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691504.7680323911, 'Reg_MSE': 47.940187341, 'Class_CXE': 691456.8278450501}
Epoch: 983 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691491.3860326748, 'Reg_MSE': 47.9172117815, 'Class_CXE': 691443.4688208933}
Epoch: 984 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691478.0042086475, 'Reg_MSE': 47.8942781221, 'Class_CXE': 691430.1099305254}
Epoch: 985 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691464.6498318041, 'Reg_MSE': 47.8713862078, 'Class_CXE': 691416.7784455963}
Epoch: 986 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691451.2670545409, 'Reg_MSE': 47.848535884, 'Class_CXE': 691403.4185186569}
Epoch: 987 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691437.889168995, 'Reg_MSE': 47.8257269972, 'Class_CXE': 691390.0634419977}
Epoch: 988 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691424.506316448, 'Reg_MSE': 47.8029593944, 'Class_CXE': 691376.7033570536}
Epoch: 989 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691411.1200387222, 'Reg_MSE': 47.7802329237, 'Class_CXE': 691363.3398057985}
Epoch: 990 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691772.1016473522, 'Reg_MSE': 47.7575474335, 'Class_CXE': 691724.3440999187}
Epoch: 991 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691758.7170396735, 'Reg_MSE': 47.7349027734, 'Class_CXE': 691710.9821369001}
Epoch: 992 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691745.3333904129, 'Reg_MSE': 47.7122987935, 'Class_CXE': 691697.6210916194}
Epoch: 993 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691731.9432310339, 'Reg_MSE': 47.6897353447, 'Class_CXE': 691684.2534956891}
Epoch: 994 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691718.554345367, 'Reg_MSE': 47.6672122787, 'Class_CXE': 691670.8871330883}
Epoch: 995 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691705.1658094267, 'Reg_MSE': 47.6447294478, 'Class_CXE': 691657.5210799789}
Epoch: 996 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691691.7660192155, 'Reg_MSE': 47.6222867053, 'Class_CXE': 691644.1437325102}
Epoch: 997 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691678.3741547985, 'Reg_MSE': 47.5998839048, 'Class_CXE': 691630.7742708937}
Epoch: 998 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691664.9886220257, 'Reg_MSE': 47.577520901, 'Class_CXE': 691617.4111011247}
Epoch: 999 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691651.5988260263, 'Reg_MSE': 47.5551975492, 'Class_CXE': 691604.0436284771}
Epoch: 1000 - {'Class_acc': 0.525623268698061, 'CXE+MSE': 691638.2151663445, 'Reg_MSE': 47.5329137055, 'Class_CXE': 691590.6822526391}
np.random.seed(42)
from sklearn.linear_model import LogisticRegression
model_lr_sklearn = LogisticRegression( C=1e6, solver="sag", max_iter=15)
model_lr_sklearn.fit(X_train_c,y_train_c)
LogisticRegression(C=1000000.0, max_iter=15, solver='sag')
predicted_c=model_lr_sklearn.predict(X_test_c)
acc = accuracy_score(y_test_c, predicted_c)
acc
0.49230769230769234
In order to implement the homegrown version of logistic regression to classify and regress at the same time, we created a class called HomeGrownLogisticRegression that consists of important methods instrumental in training a model.
The input into the model was a 32x32x3 flattened numpy array of all the images.
Using gradient updating of weights, we maintained a theta matrix that learnt the weights on how to classify an image as cat or dog and at the same time we maintained another theta matrix that learns the weights to predict the Xmin, Xmax, Ymin and Ymax values using unnormalized distances in logistic regression, aka linear regression.
We observed that without a learning rate scheduler, we had to use a very small lr value to observe the CXE+MSE values reducing.
We obtained a classification accuracy of 52.7% on the validation data and a MSE on its way to convergence at the end of 1000 epochs.
In this series we'll be building machine learning models (specifically, neural networks) to perform image classification using PyTorch and Torchvision.
In this first notebook, we'll start with one of the most basic neural network architectures, a multilayer perceptron (MLP), also known as a feedforward network. The dataset we'll be using is the famous MNIST dataset, a dataset of 28x28 black and white images consisting of handwritten digits, 0 to 9.
We'll process the dataset, build our model and then train our model. Afterwards we'll do a short dive into what the model has actually learned.
Let's start by importing all of the modules we'll need. The main ones we need to import are:
from collections import Counter
import glob
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from PIL import Image
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import SGDClassifier, SGDRegressor
from sklearn.metrics import accuracy_score, mean_squared_error, roc_auc_score
from sklearn.model_selection import train_test_split
import tarfile
from tqdm.notebook import tqdm
import warnings
def extract_tar(file, path):
"""
function to extract tar.gz files to specified location
Args:
file (str): path where the file is located
path (str): path where you want to extract
"""
with tarfile.open(file) as tar:
files_extracted = 0
for member in tqdm(tar.getmembers()):
if os.path.isfile(path + member.name[1:]):
continue
else:
tar.extract(member, path)
files_extracted += 1
tar.close()
if files_extracted < 3:
print('Files already exist')
path = 'images/'
extract_tar('/content/drive/MyDrive/AML_Project/cadod.tar.gz', path)
Files already exist
df = pd.read_csv('/content/drive/MyDrive/AML_Project/cadod.csv')
df.head()
| ImageID | Source | LabelName | Confidence | XMin | XMax | YMin | YMax | IsOccluded | IsTruncated | IsGroupOf | IsDepiction | IsInside | XClick1X | XClick2X | XClick3X | XClick4X | XClick1Y | XClick2Y | XClick3Y | XClick4Y | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0000b9fcba019d36 | xclick | /m/0bt9lr | 1 | 0.165000 | 0.903750 | 0.268333 | 0.998333 | 1 | 1 | 0 | 0 | 0 | 0.636250 | 0.903750 | 0.748750 | 0.165000 | 0.268333 | 0.506667 | 0.998333 | 0.661667 |
| 1 | 0000cb13febe0138 | xclick | /m/0bt9lr | 1 | 0.000000 | 0.651875 | 0.000000 | 0.999062 | 1 | 1 | 0 | 0 | 0 | 0.312500 | 0.000000 | 0.317500 | 0.651875 | 0.000000 | 0.410882 | 0.999062 | 0.999062 |
| 2 | 0005a9520eb22c19 | xclick | /m/0bt9lr | 1 | 0.094167 | 0.611667 | 0.055626 | 0.998736 | 1 | 1 | 0 | 0 | 0 | 0.487500 | 0.611667 | 0.243333 | 0.094167 | 0.055626 | 0.226296 | 0.998736 | 0.305942 |
| 3 | 0006303f02219b07 | xclick | /m/0bt9lr | 1 | 0.000000 | 0.999219 | 0.000000 | 0.998824 | 1 | 1 | 0 | 0 | 0 | 0.508594 | 0.999219 | 0.000000 | 0.478906 | 0.000000 | 0.375294 | 0.720000 | 0.998824 |
| 4 | 00064d23bf997652 | xclick | /m/0bt9lr | 1 | 0.240938 | 0.906183 | 0.000000 | 0.694286 | 0 | 0 | 0 | 0 | 0 | 0.678038 | 0.906183 | 0.240938 | 0.522388 | 0.000000 | 0.370000 | 0.424286 | 0.694286 |
df.head()
| ImageID | Source | LabelName | Confidence | XMin | XMax | YMin | YMax | IsOccluded | IsTruncated | IsGroupOf | IsDepiction | IsInside | XClick1X | XClick2X | XClick3X | XClick4X | XClick1Y | XClick2Y | XClick3Y | XClick4Y | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0000b9fcba019d36 | xclick | /m/0bt9lr | 1 | 0.165000 | 0.903750 | 0.268333 | 0.998333 | 1 | 1 | 0 | 0 | 0 | 0.636250 | 0.903750 | 0.748750 | 0.165000 | 0.268333 | 0.506667 | 0.998333 | 0.661667 |
| 1 | 0000cb13febe0138 | xclick | /m/0bt9lr | 1 | 0.000000 | 0.651875 | 0.000000 | 0.999062 | 1 | 1 | 0 | 0 | 0 | 0.312500 | 0.000000 | 0.317500 | 0.651875 | 0.000000 | 0.410882 | 0.999062 | 0.999062 |
| 2 | 0005a9520eb22c19 | xclick | /m/0bt9lr | 1 | 0.094167 | 0.611667 | 0.055626 | 0.998736 | 1 | 1 | 0 | 0 | 0 | 0.487500 | 0.611667 | 0.243333 | 0.094167 | 0.055626 | 0.226296 | 0.998736 | 0.305942 |
| 3 | 0006303f02219b07 | xclick | /m/0bt9lr | 1 | 0.000000 | 0.999219 | 0.000000 | 0.998824 | 1 | 1 | 0 | 0 | 0 | 0.508594 | 0.999219 | 0.000000 | 0.478906 | 0.000000 | 0.375294 | 0.720000 | 0.998824 |
| 4 | 00064d23bf997652 | xclick | /m/0bt9lr | 1 | 0.240938 | 0.906183 | 0.000000 | 0.694286 | 0 | 0 | 0 | 0 | 0 | 0.678038 | 0.906183 | 0.240938 | 0.522388 | 0.000000 | 0.370000 | 0.424286 | 0.694286 |
df.columns
Index(['ImageID', 'Source', 'LabelName', 'Confidence', 'XMin', 'XMax', 'YMin',
'YMax', 'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsDepiction',
'IsInside', 'XClick1X', 'XClick2X', 'XClick3X', 'XClick4X', 'XClick1Y',
'XClick2Y', 'XClick3Y', 'XClick4Y'],
dtype='object')
df.LabelName.replace({'/m/01yrx':'cat', '/m/0bt9lr':'dog'}, inplace=True)
img_shape = []
img_size = np.zeros((df.shape[0], 1))
for i,f in enumerate(tqdm(glob.glob1(path, '*.jpg'))):
file = path+'/'+f
img = Image.open(file)
img_shape.append(f"{img.size[0]}x{img.size[1]}")
img_size[i] += os.path.getsize(file)
!mkdir -p images/resized
# from torch.utils.data import DataLoader, Dataset
# import torchvision.transforms as T
# import torch
# import torch.nn as nn
# from torchvision.utils import make_grid
# from torchvision.utils import save_image
# # from IPython.display import Image
# import matplotlib.pyplot as plt
# import numpy as np
# import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from sklearn import metrics
from sklearn import decomposition
from sklearn import manifold
import matplotlib.pyplot as plt
import numpy as np
import copy
import random
import time
# resize image and save, convert to numpy
img_arr = np.zeros((df.shape[0],32*32*3)) # initialize np.array
# mean=[]
# std=[]
for i, f in enumerate(tqdm(df.ImageID)):
img = Image.open(path+f+'.jpg')
img_resized = img.resize((32,32))
img_resized.save("images/resized/"+f+'.jpg', "JPEG", optimize=True)
img_arr[i] = np.asarray(img_resized, dtype=np.uint8).flatten()
# mean.append(np.asarray(img_resized, dtype=np.uint8).flatten().mean()/255)
# std.append(np.asarray(img_resized, dtype=np.uint8).flatten().std()/255)
# encode labels
df['Label'] = (df.LabelName == 'dog').astype(np.uint8)
df.columns
Index(['ImageID', 'Source', 'LabelName', 'Confidence', 'XMin', 'XMax', 'YMin',
'YMax', 'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsDepiction',
'IsInside', 'XClick1X', 'XClick2X', 'XClick3X', 'XClick4X', 'XClick1Y',
'XClick2Y', 'XClick3Y', 'XClick4Y', 'Label'],
dtype='object')
#spliting the data into train test data
import numpy as np
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(img_arr, df['Label'], test_size=0.33, random_state=42)
#Normalizing the pixel values by substracting with mean and didviding with standard devaiation
mean_train=[]
std_train=[]
for i in range(len(X_train)):
#appending mean of every image to mean_train list
mean_train.append(np.array(X_train[i], dtype=np.uint8).flatten().mean()/255)
#appending mean of every image to std_train list
std_train.append(np.array(X_train[i], dtype=np.uint8).flatten().std()/255)
mean_test=[]
std_test=[]
for i in range(len(X_test)):
mean_test.append(np.array(X_test[i], dtype=np.uint8).flatten().mean()/255)
std_test.append(np.array(X_test[i], dtype=np.uint8).flatten().std()/255)
#normalizing part 2
my_train_data=[]
my_test_data=[]
for i in range(len(X_train)):
my_train_data.append(
(
torch.tensor(
np.round(
(((X_train[i]/255)-mean_train[i])/std_train[i]),5
).astype(np.float32)
),
int(y_train.tolist()[i])
)
)
for i in range(len(X_test)):
my_test_data.append(
(
torch.tensor(
np.round(
(((X_test[i]/255)-mean_test[i])/std_test[i]),5
).astype(np.float32)
),
int(y_test.tolist()[i])
)
)
#spliting the train data into train validation split
import numpy as np
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(my_train_data, y_train, test_size=0.33, random_state=42)
train_data=X_train
valid_data=X_val
test_data=my_test_data
#visualizing the sample sizes
print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')
print(f'Number of testing examples: {len(test_data)}')
Number of training examples: 5820 Number of validation examples: 2867 Number of testing examples: 4279
#visualising one tensor
X_train[0]
(tensor([-0.9506, -1.4124, -1.2007, ..., -0.4117, -0.9121, -0.7774]), 1)
To ensure we get reproducible results we set the random seed for Python, Numpy and PyTorch.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
Now we have defined our transforms we can then load the train and test data with the relevant transforms defined.
(my_train_data[0])
(tensor([ 0.4145, 0.4442, -0.1951, ..., 0.0279, 0.1915, -0.4032]), 1)
Next, we'll define a DataLoader for each of the training/validation/test sets. We can iterate over these and they will yield batches of images and labels which we can use to train our model.
We only need to shuffle our training set as it will be used for stochastic gradient descent and we want the each batch to be different between epochs. As we aren't using the validation or test sets to update our model parameters they do not need to be shuffled.
Ideally, we want to use the biggest batch size that we can. The 64 here is relatively small and can be increased if our hardware can handle it.
#In order to take a batch of 500 images at once , we are making use of data loader for train, test and valid data
BATCH_SIZE = 500
train_iterator = data.DataLoader(train_data,
shuffle = True,
batch_size = BATCH_SIZE)
valid_iterator = data.DataLoader(valid_data,
batch_size = BATCH_SIZE)
test_iterator = data.DataLoader(test_data,
batch_size = BATCH_SIZE)
#Model: Our model has four layers input dim=3072, hidden layer one=250, hidden layer two =100, output layer=2
class MLP(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.input_fc = nn.Linear(input_dim, 250)
self.hidden_fc = nn.Linear(250, 100)
self.output_fc = nn.Linear(100, output_dim)
#self.output_fc_c = nn.Linear(100, 1)
#self.output_fc_r = nn.Linear(100, 4)
def forward(self, x):
#x = [batch size, height, width]
batch_size = x.shape[0]
x = x.view(batch_size, -1)
#x = [batch size, height * width]
h_1 = F.relu(self.input_fc(x))
#h_1 = [batch size, 250]
h_2 = F.relu(self.hidden_fc(h_1))
#h_2 = [batch size, 100]
y_pred = self.output_fc(h_2)
#y_pred2=self.output_fc_r(h_2)
#y_pred = [batch size, output dim]
return y_pred, h_2
We'll define our model by creating an instance of it and setting the correct input and output dimensions.
#defining the input and out put dimesnsions
INPUT_DIM = 32 * 32 *3
OUTPUT_DIM = 2
#creating an object for classification model
model = MLP(INPUT_DIM, OUTPUT_DIM)
We can also create a small function to calculate the number of trainable parameters (weights and biases) in our model - in case all of our parameters are trainable.
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
The first layer has 3072 neurons connected to 250 neurons, so 3072*250 weighted connections plus 250 bias terms.
The second layer has 250 neurons connected to 100 neurons, 250*100 weighted connections plus 100 bias terms.
The third layer has 100 neurons connected to 2 neurons, 100*2 weighted connections plus 10 bias terms.
$$3072 \cdot 250 + 250 + 250 \cdot 100 + 100 + 100 \cdot 2 + 2= 222,360 $$print(f'The model has {count_parameters(model):,} trainable parameters')
The model has 793,552 trainable parameters
Next, we'll define our optimizer. This is the algorithm we will use to update the parameters of our model with respect to the loss calculated on the data.
We aren't going to go into too much detail on how neural networks are trained (see this article if you want to know how) but the gist is:
We use the Adam algorithm with the default parameters to update our model. Improved results could be obtained by searching over different optimizers and learning rates, however default Adam is usually a good starting off point. Check out this article if you want to learn more about the different optimization algorithms commonly used for neural networks.
optimizer = optim.Adam(model.parameters())
Then, we define a criterion, PyTorch's name for a loss/cost/error function. This function will take in your model's predictions with the actual labels and then compute the loss/cost/error of your model with its current parameters.
CrossEntropyLoss both computes the softmax activation function on the supplied predictions as well as the actual loss via negative log likelihood.
Briefly, the softmax function is:
$$\text{softmax }(\mathbf{x}) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$This turns out 10 dimensional output, where each element is an unbounded real number, into a probability distribution over 10 elements. That is, all values are between 0 and 1, and together they all sum to 1.
Why do we turn things into a probability distribution? So we can use negative log likelihood for our loss function as it expects probabilities. PyTorch calculates negative log likelihood for a single example via:
$$\text{negative log likelihood }(\mathbf{\hat{y}}, y) = -\log \big( \text{softmax}(\mathbf{\hat{y}})[y] \big)$$$\mathbf{\hat{y}}$ is the $\mathbb{R}^{10}$ output, from our neural network, whereas $y$ is the label, an integer representing the class. The loss is the negative log of the class index of the softmax. For example:
$$\mathbf{\hat{y}} = [5,1,1,1,1,1,1,1,1,1]$$$$\text{softmax }(\mathbf{\hat{y}}) = [0.8585, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157]$$If the label was class zero, the loss would be:
$$\text{negative log likelihood }(\mathbf{\hat{y}}, 0) = - \log(0.8585) = 0.153 \dots$$If the label was class five, the loss would be:
$$\text{negative log likelihood }(\mathbf{\hat{y}}, 5) = - \log(0.0157) = 4.154 \dots$$So, intuitively, as your model's output corresponding to the correct class index increases your loss decreases.
#tTaking Cross Entrphy Loss
criterion = nn.CrossEntropyLoss()
We then define device. This is used to place your model and data on to a GPU, if you have one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
We place our model and criterion on to the device by using the .to method.
model = model.to(device)
criterion = criterion.to(device)
Next, we'll define a function to calculate the accuracy of our model. This takes the index of the highest value for your prediction and compares it against the actual class label. We then divide how many our model got correct by the amount in the batch to calculate accuracy across the batch.
def calculate_accuracy(y_pred, y):
top_pred = y_pred.argmax(1, keepdim = True)
correct = top_pred.eq(y.view_as(top_pred)).sum()
acc = correct.float() / y.shape[0]
return acc
We finally define our training loop.
This will:
train modex, through to model to get predictions, y_predSome layers act differently when training and evaluating the model that contains them, hence why we must tell our model we are in "training" mode. The model we are using here does not use any of those layers, however it is good practice to get used to putting your model in training mode.
#Training the model by passing model, train iterator to iterator, optimser as adams, criterion as CXE
def train(model, iterator, optimizer, criterion, device):
epoch_loss = 0
epoch_acc = 0
model.train()
for (x, y) in iterator:
x = x.to(device)
y = y.to(device)
#y2= 4 elements for reg
optimizer.zero_grad()
y_pred, _ = model(x)
loss = criterion(y_pred, y)
#loss2=(ypred2,y2)
# print(loss+loss,"\n",loss.item())
acc = calculate_accuracy(y_pred, y)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_acc += acc.item()
return epoch_loss / len(iterator), epoch_acc / len(iterator)
The evaluation loop is similar to the training loop. The differences are:
model.eval()with torch.no_grad()torch.no_grad() ensures that gradients are not calculated for whatever is inside the with block. As our model will not have to calculate gradients it will be faster and use less memory.
#We stop training and start evaluating the model
def evaluate(model, iterator, criterion, device):
epoch_loss = 0
epoch_acc = 0
model.eval()
with torch.no_grad():
for (x, y) in iterator:
x = x.to(device)
y = y.to(device)
y_pred, _ = model(x.float())
loss = criterion(y_pred, y)
acc = calculate_accuracy(y_pred, y)
epoch_loss += loss.item()
epoch_acc += acc.item()
return epoch_loss / len(iterator), epoch_acc / len(iterator)
The final step before training is to define a small function to tell us how long an epoch took.
def epoch_time(start_time, end_time):
elapsed_time = end_time - start_time
elapsed_mins = int(elapsed_time / 60)
elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
return elapsed_mins, elapsed_secs
We're finally ready to train!
During each epoch we calculate the training loss and accuracy, followed by the validation loss and accuracy. We then check if the validation loss achieved is the best validation loss we have seen. If so, we save our model's parameters (called a state_dict).
#This is where we call the train and evaluate functions to actually evaluate the model
EPOCHS = 10
best_valid_loss = float('inf')
for epoch in range(EPOCHS):
start_time = time.monotonic()
train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)
valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), 'tut1-model.pt')
end_time = time.monotonic()
epoch_mins, epoch_secs = epoch_time(start_time, end_time)
print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
Epoch: 01 | Epoch Time: 0m 0s Train Loss: 0.693 | Train Acc: 55.41% Val. Loss: 0.695 | Val. Acc: 55.53% Epoch: 02 | Epoch Time: 0m 0s Train Loss: 0.647 | Train Acc: 61.86% Val. Loss: 0.690 | Val. Acc: 56.10% Epoch: 03 | Epoch Time: 0m 0s Train Loss: 0.603 | Train Acc: 67.14% Val. Loss: 0.689 | Val. Acc: 56.72% Epoch: 04 | Epoch Time: 0m 0s Train Loss: 0.552 | Train Acc: 72.49% Val. Loss: 0.703 | Val. Acc: 56.93% Epoch: 05 | Epoch Time: 0m 0s Train Loss: 0.494 | Train Acc: 77.36% Val. Loss: 0.751 | Val. Acc: 56.86% Epoch: 06 | Epoch Time: 0m 0s Train Loss: 0.440 | Train Acc: 80.29% Val. Loss: 0.774 | Val. Acc: 55.93% Epoch: 07 | Epoch Time: 0m 0s Train Loss: 0.365 | Train Acc: 85.55% Val. Loss: 0.843 | Val. Acc: 54.54% Epoch: 08 | Epoch Time: 0m 0s Train Loss: 0.302 | Train Acc: 89.53% Val. Loss: 0.893 | Val. Acc: 55.30% Epoch: 09 | Epoch Time: 0m 0s Train Loss: 0.232 | Train Acc: 92.54% Val. Loss: 0.972 | Val. Acc: 56.32% Epoch: 10 | Epoch Time: 0m 0s Train Loss: 0.180 | Train Acc: 94.63% Val. Loss: 1.023 | Val. Acc: 54.16%
Afterwards, we load our the parameters of the model that achieved the best validation loss and then use this to evaluate our model on the test set.
model.load_state_dict(torch.load('tut1-model.pt'))
test_loss, test_acc = evaluate(model, test_iterator, criterion, device)
Our model achieves 55.93 % accuracy on the test set.
This can be improved by tweaking hyperparameters, e.g. number of layers, number of neurons per layer, optimization algorithm used, learning rate, etc.
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
Test Loss: 0.701 | Test Acc: 55.93%
experiment_results = pd.DataFrame(columns=["Model_name","Test_Accuracy","Test_Loss","Parameters"])
from IPython.display import display, HTML
def wrap_df_text(experiment_results):
return display(HTML(experiment_results.to_html().replace("\\n","<br>")))
experiment_results['Parameters'] = experiment_results['Parameters'].str.wrap(30)
exp_name = f"MLP Classification without Drop out"
parameters = f"Optimizer:Adam,Activaton:Relu,HiddenLayer:2+,Loss:CXE"
experiment_results.loc[0,:10] = [f"{exp_name}"]+list([np.round(test_acc*100,3),np.round(test_loss,3),parameters])
wrap_df_text(experiment_results)
| Model_name | Test_Accuracy | Test_Loss | Parameters | |
|---|---|---|---|---|
| 0 | MLP Classification without Drop out | 55.934 | 0.701 | Optimizer:Adam,Activaton:Relu,HiddenLayer:2+,Loss:CXE |
Now we've trained our model there's a few things we can look at. Most of these are simple exploratory analysis, but they can offer some insights into your model.
An important thing to do is check what examples your model gets wrong and ensure that they're reasonable mistakes.
The function below will return the model's predictions over a given dataset. It will return the inputs (image) the outputs (model predictions) and the ground truth labels.
def get_predictions(model, iterator, device):
model.eval()
images = []
labels = []
probs = []
with torch.no_grad():
for (x, y) in iterator:
x = x.to(device)
y_pred, _ = model(x)
y_prob = F.softmax(y_pred, dim = -1)
top_pred = y_prob.argmax(1, keepdim = True)
images.append(x.cpu())
labels.append(y.cpu())
probs.append(y_prob.cpu())
images = torch.cat(images, dim = 0)
labels = torch.cat(labels, dim = 0)
probs = torch.cat(probs, dim = 0)
return images, labels, probs
We can then get these predictions and, by taking the index of the highest predicted probability, get the predicted labels.
images, labels, probs = get_predictions(model, test_iterator, device)
pred_labels = torch.argmax(probs, 1)
Then, we can make a confusion matrix from our actual labels and our predicted labels.
def plot_confusion_matrix(labels, pred_labels):
fig = plt.figure(figsize = (10, 10));
ax = fig.add_subplot(1, 1, 1);
cm = metrics.confusion_matrix(labels, pred_labels);
cm = metrics.ConfusionMatrixDisplay(cm, display_labels = range(10));
cm.plot(values_format = 'd', cmap = 'Blues', ax = ax)
The results seem reasonable enough, the most confused predictions-actuals are: 3-5 and 2-7.
plot_confusion_matrix(labels, pred_labels)
Next, for each of our examples, we can check if our predicted label matches our actual label.
corrects = torch.eq(labels, pred_labels)
We can then loop through all of the examples over our model's predictions and store all the examples the model got incorrect into an array.
Then, we sort these incorrect examples by how confident they were, with the most confident being first.
incorrect_examples = []
for image, label, prob, correct in zip(images, labels, probs, corrects):
if not correct:
incorrect_examples.append((image, label, prob))
incorrect_examples.sort(reverse = True, key = lambda x: torch.max(x[2], dim = 0).values)
We can then plot the incorrectly predicted images along with how confident they were on the actual label and how confident they were at the incorrect label.
def plot_most_incorrect(incorrect, n_images):
rows = int(np.sqrt(n_images))
cols = int(np.sqrt(n_images))
fig = plt.figure(figsize = (20, 10))
for i in range(rows*cols):
ax = fig.add_subplot(rows, cols, i+1)
image, true_label, probs = incorrect[i]
true_prob = probs[true_label]
incorrect_prob, incorrect_label = torch.max(probs, dim = 0)
ax.imshow(image.view(28, 28).cpu().numpy(), cmap='bone')
ax.set_title(f'true label: {true_label} ({true_prob:.3f})\n' \
f'pred label: {incorrect_label} ({incorrect_prob:.3f})')
ax.axis('off')
fig.subplots_adjust(hspace= 0.5)
# N_IMAGES = 25
# plot_most_incorrect(incorrect_examples, N_IMAGES)
Another thing we can do is get the output and intermediate representations from the model and try to visualize them.
The function below loops through the provided dataset and gets the output from the model and the intermediate representation from the layer before that, the second hidden layer.
def get_representations(model, iterator, device):
model.eval()
outputs = []
intermediates = []
labels = []
with torch.no_grad():
for (x, y) in iterator:
x = x.to(device)
y_pred, h = model(x)
outputs.append(y_pred.cpu())
intermediates.append(h.cpu())
labels.append(y)
outputs = torch.cat(outputs, dim = 0)
intermediates = torch.cat(intermediates, dim = 0)
labels = torch.cat(labels, dim = 0)
return outputs, intermediates, labels
We run the function to get the representations.
outputs, intermediates, labels = get_representations(model, train_iterator, device)
The data we want to visualize is in ten dimensions and 100 dimensions. We want to get this down to two dimensions so we can actually plot it.
The first technique we'll use is PCA (principal component analysis). First, we'll define a function to calculate the PCA of our data and then we'll define a function to plot it.
def get_pca(data, n_components = 2):
pca = decomposition.PCA()
pca.n_components = n_components
pca_data = pca.fit_transform(data)
return pca_data
def plot_representations(data, labels, n_images = None):
if n_images is not None:
data = data[:n_images]
labels = labels[:n_images]
fig = plt.figure(figsize = (10, 10))
ax = fig.add_subplot(111)
scatter = ax.scatter(data[:, 0], data[:, 1], c = labels, cmap = 'tab10')
handles, labels = scatter.legend_elements()
legend = ax.legend(handles = handles, labels = labels)
First, we plot the representations from the ten dimensional output layer, reduced down to two dimensions.
output_pca_data = get_pca(outputs)
plot_representations(output_pca_data, labels)
Next, we'll plot the outputs of the second hidden layer.
The clusters seem similar to the one above. In fact if we rotated the below image anti-clockwise it wouldn't be too far off the PCA of the output representations.
intermediate_pca_data = get_pca(intermediates)
plot_representations(intermediate_pca_data, labels)
An alternative to PCA is t-SNE (t-distributed stochastic neighbor embedding).
This is commonly thought of as being "better" than PCA, although it can be misinterpreted.
def get_tsne(data, n_components = 2, n_images = None):
if n_images is not None:
data = data[:n_images]
tsne = manifold.TSNE(n_components = n_components, random_state = 0)
tsne_data = tsne.fit_transform(data)
return tsne_data
t-SNE is very slow, so we only compute it on a subset of the representations.
The classes look very well separated, and it is possible to use k-NN on this representation to achieve decent accuracy.
N_IMAGES = 5_000
output_tsne_data = get_tsne(outputs, n_images = N_IMAGES)
plot_representations(output_tsne_data, labels, n_images = N_IMAGES)
We plot the intermediate representations on the same subset.
Again, the classes look well separated, though less so than the output representations. This is because these representations are intermediate features that the neural network has extracted and will use them in further layers to weigh up the evidence of what digit is in the image. Hence, in theory, the classes should become more separated the closer we are to the output layer, which is exactly what we see here.
intermediate_tsne_data = get_tsne(intermediates, n_images = N_IMAGES)
plot_representations(intermediate_tsne_data, labels, n_images = N_IMAGES)
Another experiment we can do is try and generate fake digits.
The function below will repeatedly generate random noise and feed it through the model and find the most confidently generated digit for the desired class.
def imagine_digit(model, digit, device, n_iterations = 50_000):
model.eval()
best_prob = 0
best_image = None
with torch.no_grad():
for _ in range(n_iterations):
x = torch.randn(32, 28, 28).to(device)
y_pred, _ = model(x)
preds = F.softmax(y_pred, dim = -1)
_best_prob, index = torch.max(preds[:,digit], dim = 0)
if _best_prob > best_prob:
best_prob = _best_prob
best_image = x[index]
return best_image, best_prob
Finally, we can plot the weights in the first layer of our model.
The hope is that there's maybe one neuron in this first layer that's learned to look for certain patterns in the input and thus has high weight values indicating this pattern. If we then plot these weights we should see these patterns.
def plot_weights(weights, n_weights):
rows = int(np.sqrt(n_weights))
cols = int(np.sqrt(n_weights))
fig = plt.figure(figsize = (20, 10))
for i in range(rows*cols):
ax = fig.add_subplot(rows, cols, i+1)
ax.imshow(weights[i].view(28, 28).cpu().numpy(), cmap = 'bone')
ax.axis('off')
Looking at these weights we see a few of them look like random noise but some of them do have weird patterns within them. These patterns show "ghostly" image looking shapes, but are clearly not images.
In this notebook we have shown:
In the next notebook we'll implement a convolutional neural network (CNN) and evaluate it on the MNIST dataset.
Build another PyTorch model for regression (using a multilayer perceptron (MLP)) with 4 target values [y_1, y_2, y_3, y_4] corresponding to [x, y, w, h] of the bounding box containing the object of interest).
# resize image and save, convert to numpy
img_arr_reg = [] # initialize np.array
for i, f in enumerate(tqdm(df.ImageID)):
img = Image.open(path+f+'.jpg')
img_resized = img.resize((32,32))
img_resized.save("images/resized/"+f+'.jpg', "JPEG", optimize=True)
img_arr_reg.append( np.asarray(img_resized, dtype=np.uint8))
df.columns
Index(['ImageID', 'Source', 'LabelName', 'Confidence', 'XMin', 'XMax', 'YMin',
'YMax', 'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsDepiction',
'IsInside', 'XClick1X', 'XClick2X', 'XClick3X', 'XClick4X', 'XClick1Y',
'XClick2Y', 'XClick3Y', 'XClick4Y', 'Label'],
dtype='object')
mkdir -p data
np.save('data/y_bbox.npy', df[['XMin', 'YMin', 'XMax', 'YMax']].values.astype(np.float32))
y_bbox = np.load('data/y_bbox.npy', allow_pickle=True)
#bounding box dimensions which are to be predicted
y_bbox
array([[0.165 , 0.268333, 0.90375 , 0.998333],
[0. , 0. , 0.651875, 0.999062],
[0.094167, 0.055626, 0.611667, 0.998736],
...,
[0.001475, 0.042406, 0.988201, 0.62426 ],
[0. , 0.037523, 0.998125, 0.999062],
[0.148045, 0.07064 , 0.999069, 0.94702 ]], dtype=float32)
#splitting the data into test train splits
import numpy as np
from sklearn.model_selection import train_test_split
X_train_r, X_test_r, y_train_r, y_test_r = train_test_split(np.array(img_arr_reg)/255,y_bbox, test_size=0.33, random_state=42)
import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
class Dataset(torch.utils.data.Dataset):
'''
Prepare the Boston dataset for regression
'''
def __init__(self, X, y, scale_data=True):
if not torch.is_tensor(X) and not torch.is_tensor(y):
# Apply scaling if necessary
#if scale_data:
# X = StandardScaler().fit_transform(X)
self.X = torch.from_numpy(X)
self.y = torch.from_numpy(y)
def __len__(self):
return len(self.X)
def __getitem__(self, i):
return self.X[i], self.y[i]
#CREATING A NN with sequential in pytorch with input_dim=3072, hidden layer 1 dim= 250, hidden layer 2 dim=100, output layer:4
class MLP1(nn.Module):
'''
Multilayer Perceptron for regression.
'''
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(3072,250),
nn.ReLU(),
nn.Linear(250, 100),
nn.ReLU(),
nn.Linear(100, 4)
)
def forward(self, x):
'''
Forward pass
'''
return self.layers(x)
mean_loss=[]
#we call the model, train the model , measure the model every epoch by loss
if __name__ == '__main__':
# Set fixed random number seed
torch.manual_seed(42)
# Load Boston dataset
X, y = X_train_r,y_train_r
# Prepare Boston dataset
dataset = Dataset(np.array(X), np.array(y))
trainloader = torch.utils.data.DataLoader(dataset, batch_size=500, shuffle=True, num_workers=1)
# Initialize the MLP
mlp = MLP1()
# Define the loss function and optimizer
loss_function = nn.L1Loss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)
# Run the training loop
for epoch in range(0, 20): # 5 epochs at maximum
# Print epoch
print(f'Starting epoch {epoch+1}')
# Set current loss value
current_loss = 0.0
train_count,train_loss=0,0
# Iterate over the DataLoader for training data
for i, data in enumerate(trainloader,0):
# Get and prepare inputs
inputs, targets = data
inputs, targets = inputs.float(), targets.float()
#targets = targets.reshape((targets.shape[0], 1))
# Zero the gradients
optimizer.zero_grad()
loss = 0
# Perform forward pass
for num in range(len(inputs)):
outputs=mlp(inputs[num].flatten())
# Compute loss
loss += loss_function(outputs, targets[num])/500
# Perform backward pass
loss.backward()
# Perform optimization
optimizer.step()
# Print statistics
train_loss += loss.item()
#print(train_loss)
train_count+=1
# if i % 10 == 0:
# print('Loss after mini-batch %5d: %.3f' %
# (i + 1, current_loss ))
# mean_loss.append(current_loss)
# current_loss = 0.0
print(train_loss/train_count)
# Process is complete.
print('Training process has finished.')
Starting epoch 1 0.2078897510137823 Starting epoch 2 0.13889307797782952 Starting epoch 3 0.1308961429943641 Starting epoch 4 0.1270570307970047 Starting epoch 5 0.12442695080406135 Starting epoch 6 0.1220286109795173 Starting epoch 7 0.1199258305132389 Starting epoch 8 0.11825432214472029 Starting epoch 9 0.11655832868483332 Starting epoch 10 0.11532921488914225 Starting epoch 11 0.11409979831013414 Starting epoch 12 0.11258601935373412 Starting epoch 13 0.1111256502982643 Starting epoch 14 0.11000728834834364 Starting epoch 15 0.10932950282262431 Starting epoch 16 0.10754976669947307 Starting epoch 17 0.107028609348668 Starting epoch 18 0.10570604250662857 Starting epoch 19 0.10458212697671519 Starting epoch 20 0.10318754965232478 Training process has finished.
#we call the model, test the model , measure the model every epoch by loss
mean_loss_test=[]
# Set fixed random number seed
torch.manual_seed(42)
# Load Boston dataset
X, y = X_test_r,y_test_r
# Prepare Boston dataset
dataset = Dataset(np.array(X), np.array(y))
testloader = torch.utils.data.DataLoader(dataset, batch_size=500, shuffle=True, num_workers=1)
# Initialize the MLP
# mlp = MLP()
# Define the loss function and optimizer
# loss_function = nn.L1Loss()
# optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)
# Run the training loop
for epoch in range(0, 1): # 5 epochs at maximum
# Print epoch
print(f'test loss')
# Set current loss value
current_loss = 0.0
count=0
test_loss=0
# Iterate over the DataLoader for training data
for i, data in enumerate(testloader,0):
# Get and prepare inputs
inputs, targets = data
inputs, targets = inputs.float(), targets.float()
#targets = targets.reshape((targets.shape[0], 1))
# Zero the gradients
# optimizer.zero_grad()
loss = 0
# Perform forward pass
for num in range(len(inputs)):
#outputs = []
outputs=mlp(inputs[num].flatten())
# Compute loss
loss += loss_function(outputs, targets[num])/500
# Perform backward pass
# loss.backward()
# Perform optimization
# optimizer.step()
count+=1
# Print statistics
test_loss += loss.item()
#print(loss.item())
# if i % 10 == 0:
# print('Loss after mini-batch %5d: %.3f' %
# (i + 1, current_loss ))
# mean_loss_test.append(current_loss)
# current_loss = 0.0
print(test_loss/count)
test loss 0.10604933318164614
test_loss = test_loss/count
exp_name = f"MLP Regression without Drop out"
parameters = f"Optimizer:Adam with lr=1e-4,Activaton:Relu,HiddenLayer:2,Loss:MSE"
experiment_results.loc[len(experiment_results)] = [f"{exp_name}"]+list(['NAN',np.round(test_loss,3),parameters])
wrap_df_text(experiment_results)
| Model_name | Test_Accuracy | Test_Loss | Parameters | |
|---|---|---|---|---|
| 0 | MLP Classification without Drop out | 55.934 | 0.701 | Optimizer:Adam,Activaton:Relu,HiddenLayer:2+,Loss:CXE |
| 1 | MLP Regression without Drop out | NAN | 0.106 | Optimizer:Adam with lr=1e-4,Activaton:Relu,HiddenLayer:2,Loss:MSE |
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from sklearn import metrics
from sklearn import decomposition
from sklearn import manifold
import matplotlib.pyplot as plt
import numpy as np
import copy
import random
import time
# resize image and save, convert to numpy
img_arr = np.zeros((df.shape[0],32*32*3)) # initialize np.array
# mean=[]
# std=[]
for i, f in enumerate(tqdm(df.ImageID)):
img = Image.open(path+f+'.jpg')
img_resized = img.resize((32,32))
img_resized.save("images/resized/"+f+'.jpg', "JPEG", optimize=True)
img_arr[i] = np.asarray(img_resized, dtype=np.uint8).flatten()
# mean.append(np.asarray(img_resized, dtype=np.uint8).flatten().mean()/255)
# std.append(np.asarray(img_resized, dtype=np.uint8).flatten().std()/255)
#spliting the data into train and test
import numpy as np
from sklearn.model_selection import train_test_split
X_train_mh, X_test_mh, y_train_mh, y_test_mh = train_test_split(img_arr, df[['Label', 'XMin', 'XMax', 'YMin','YMax']], test_size=0.33, random_state=42)
y_train_mh['Label'].tolist()[0]
1
#normalizing all the pixels by subtracting with mean and sd of that image
mean_train_mh=[]
std_train_mh=[]
for i in range(len(X_train_mh)):
mean_train_mh.append(np.array(X_train_mh[i], dtype=np.uint8).flatten().mean()/255)
std_train_mh.append(np.array(X_train_mh[i], dtype=np.uint8).flatten().std()/255)
# final_mean_train=np.array(mean_train).mean()
# final_std_train=np.array(std_train).mean()
mean_test_mh=[]
std_test_mh=[]
for i in range(len(X_test_mh)):
mean_test_mh.append(np.array(X_test_mh[i], dtype=np.uint8).flatten().mean()/255)
std_test_mh.append(np.array(X_test_mh[i], dtype=np.uint8).flatten().std()/255)
# final_mean_test=np.array(mean_test).mean()
# final_std_test=np.array(std_test).mean()
#noramilizing the test train data
my_train_data_mh=[]
for i in range(len(X_train_mh)):
my_train_data_mh.append(
(
torch.tensor(
np.round(
(((X_train_mh[i]/255)-mean_train_mh[i])/std_train_mh[i]),5
).astype(np.float32)
),
int(y_train_mh['Label'].tolist()[i]),
torch.tensor([y_train_mh['XMin'].tolist()[i],
y_train_mh['XMax'].tolist()[i],
y_train_mh['YMin'].tolist()[i],
y_train_mh['YMax'].tolist()[i]])
)
)
my_train_data_mh[0]
(tensor([ 0.4145, 0.4442, -0.1951, ..., 0.0279, 0.1915, -0.4032]), 1, tensor([0.0000, 0.7650, 0.1038, 0.9992]))
#normalizing and convert into tensors
my_test_data_mh=[]
for i in range(len(X_test_mh)):
my_test_data_mh.append(
(
torch.tensor(
np.round(
(((X_test_mh[i]/255)-mean_test_mh[i])/std_test_mh[i]),5
).astype(np.float32)
),
int(y_test_mh['Label'].tolist()[i]),
torch.tensor([y_test_mh['XMin'].tolist()[i],
y_test_mh['XMax'].tolist()[i],
y_test_mh['YMin'].tolist()[i],
y_test_mh['YMax'].tolist()[i]])
)
)
my_test_data_mh[0]
(tensor([-0.5221, -0.7394, -1.0290, ..., 1.0711, 0.7271, 0.4012]), 0, tensor([0.4102, 0.9990, 0.1259, 0.9122]))
#splitting the train data into valid and train sets.
import numpy as np
from sklearn.model_selection import train_test_split
X_train_mh, X_val_mh, y_train_mh, y_val_mh = train_test_split(my_train_data_mh, y_train_mh, test_size=0.33, random_state=42)
train_data_mh=X_train_mh
valid_data_mh=X_val_mh
test_data_mh=my_test_data_mh
print(f'Number of training examples: {len(train_data_mh)}')
print(f'Number of validation examples: {len(valid_data_mh)}')
print(f'Number of testing examples: {len(test_data_mh)}')
Number of training examples: 5820 Number of validation examples: 2867 Number of testing examples: 4279
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
#spliting the data into batch of 500 each with the help of data loader
BATCH_SIZE = 500
train_iterator_mh = data.DataLoader(train_data_mh,
shuffle = True,
batch_size = BATCH_SIZE)
valid_iterator_mh = data.DataLoader(valid_data_mh,
batch_size = BATCH_SIZE)
test_iterator_mh = data.DataLoader(test_data_mh,
batch_size = BATCH_SIZE)
#This is where we defined the model for Multihead with the same network dimesnions as above.
# we are predicting both the class as well as a bounding box by giving output from last hidden layer to two different out put layers
class MLP_mh(nn.Module):
def __init__(self, input_dim):
super().__init__()
self.input_fc = nn.Linear(input_dim, 250)
self.hidden_fc = nn.Linear(250, 100)
# self.output_fc = nn.Linear(100, output_dim)
self.output_fc_c = nn.Linear(100, 2)
self.output_fc_r = nn.Linear(100, 4)
def forward(self, x):
#x = [batch size, height, width]
batch_size = x.shape[0]
x = x.view(batch_size, -1)
#x = [batch size, height * width]
h_1 = F.relu(self.input_fc(x))
#h_1 = [batch size, 250]
h_2 = F.relu(self.hidden_fc(h_1))
#h_2 = [batch size, 100]
# y_pred = self.output_fc(h_2)
y_pred_c=self.output_fc_c(h_2)
y_pred_r=self.output_fc_r(h_2)
#y_pred = [batch size, output dim]
return y_pred_c,y_pred_r, h_2
#passing the input and output dimenions
INPUT_DIM = 32 * 32 *3
# OUTPUT_DIM = 2
#creating an object for model
model = MLP_mh(INPUT_DIM)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {count_parameters(model):,} trainable parameters')
The model has 793,956 trainable parameters
optimizer = optim.Adam(model.parameters())
model.parameters
<bound method Module.parameters of MLP_mh( (input_fc): Linear(in_features=3072, out_features=250, bias=True) (hidden_fc): Linear(in_features=250, out_features=100, bias=True) (output_fc_c): Linear(in_features=100, out_features=2, bias=True) (output_fc_r): Linear(in_features=100, out_features=4, bias=True) )>
#using two losses, MSE for regrssiona and CXE for classification
criterion_c = nn.CrossEntropyLoss()
criterion_r = nn.MSELoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion_c = criterion_c.to(device)
criterion_r = criterion_r.to(device)
#defining the accuracy function
def calculate_accuracy(y_pred, y):
top_pred = y_pred.argmax(1, keepdim = True)
correct = top_pred.eq(y.view_as(top_pred)).sum()
acc = correct.float() / y.shape[0]
return acc
# we train the model here by passing all the necessary parameters as in the above two models
def train_mh(model, iterator, optimizer, criterion_c,criterion_r, device):
epoch_loss = 0
epoch_acc = 0
model.train()
for (x, y1,y2) in iterator:
x = x.to(device)
y1 = y1.to(device)
y2 = y2.to(device)
#y2= 4 elements for reg
optimizer.zero_grad()
y_pred_c,y_pred_r ,_ = model(x)
loss_c = criterion_c(y_pred_c, y1)
#loss2=(ypred2,y2)
acc = calculate_accuracy(y_pred_c, y1)
loss_r=criterion_c(y_pred_r, y2)
loss=loss_c+loss_r
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_acc += acc.item()
return epoch_loss / len(iterator), epoch_acc / len(iterator)
def evaluate_mh(model, iterator, criterion_c,criterion_r, device):
epoch_loss = 0
epoch_acc = 0
model.eval()
with torch.no_grad():
for (x, y1,y2) in iterator:
x = x.to(device)
y1 = y1.to(device)
y2 = y2.to(device)
y_pred_c,y_pred_r, _ = model(x.float())
loss_c = criterion_c(y_pred_c, y1)
loss_r = criterion_r(y_pred_r, y2)
acc = calculate_accuracy(y_pred_c, y1)
loss=loss_c+loss_r
epoch_loss += loss.item()
epoch_acc += acc.item()
return epoch_loss / len(iterator), epoch_acc / len(iterator)
def epoch_time(start_time, end_time):
elapsed_time = end_time - start_time
elapsed_mins = int(elapsed_time / 60)
elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
return elapsed_mins, elapsed_secs
#train accuracy for one hidden layer
EPOCHS = 10
best_valid_loss = float('inf')
for epoch in range(EPOCHS):
start_time = time.monotonic()
train_loss, train_acc = train_mh(model, train_iterator_mh, optimizer, criterion_c,criterion_r, device)
valid_loss, valid_acc = evaluate_mh(model, valid_iterator_mh, criterion_c,criterion_r, device)
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), 'tut1-model.pt')
end_time = time.monotonic()
epoch_mins, epoch_secs = epoch_time(start_time, end_time)
print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
Epoch: 01 | Epoch Time: 0m 0s Train Loss: 2.921 | Train Acc: 54.55% Val. Loss: 2.301 | Val. Acc: 55.00% Epoch: 02 | Epoch Time: 0m 0s Train Loss: 2.713 | Train Acc: 61.58% Val. Loss: 2.793 | Val. Acc: 57.44% Epoch: 03 | Epoch Time: 0m 0s Train Loss: 2.662 | Train Acc: 65.16% Val. Loss: 2.253 | Val. Acc: 56.75% Epoch: 04 | Epoch Time: 0m 0s Train Loss: 2.615 | Train Acc: 69.75% Val. Loss: 2.538 | Val. Acc: 57.90% Epoch: 05 | Epoch Time: 0m 0s Train Loss: 2.566 | Train Acc: 73.85% Val. Loss: 2.480 | Val. Acc: 58.09% Epoch: 06 | Epoch Time: 0m 0s Train Loss: 2.512 | Train Acc: 78.27% Val. Loss: 2.627 | Val. Acc: 57.39% Epoch: 07 | Epoch Time: 0m 0s Train Loss: 2.455 | Train Acc: 80.69% Val. Loss: 2.675 | Val. Acc: 56.56% Epoch: 08 | Epoch Time: 0m 0s Train Loss: 2.411 | Train Acc: 83.62% Val. Loss: 2.796 | Val. Acc: 57.87% Epoch: 09 | Epoch Time: 0m 0s Train Loss: 2.349 | Train Acc: 87.08% Val. Loss: 2.894 | Val. Acc: 57.51% Epoch: 10 | Epoch Time: 0m 0s Train Loss: 2.287 | Train Acc: 90.29% Val. Loss: 3.037 | Val. Acc: 57.77%
model.load_state_dict(torch.load('tut1-model.pt'))
test_loss, test_acc = evaluate_mh(model, test_iterator_mh, criterion_c,criterion_r, device)
We have achieved the test accuracy of 56.42%
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
Test Loss: 2.265 | Test Acc: 56.42%
test_acc = test_acc*100
[test_acc,test_loss,parameters]
[56.42190376917521, 2.264857398139106, 'Optimizer:Adam with lr=1e-4,Activaton:Relu,HiddenLayer:2,Loss:MSE']
exp_name = f"Multi Head MLP Model"
parameters = f"Optimizer : Adam, Hidden Layers : 2, Activation :Relu Loss: CXE+MSE"
experiment_results.loc[len(experiment_results)] = [f"{exp_name}"]+list([np.round(test_acc,3),np.round(test_loss,3),parameters])
wrap_df_text(experiment_results)
| Model_name | Test_Accuracy | Test_Loss | Parameters | |
|---|---|---|---|---|
| 0 | MLP Classification without Drop out | 55.934 | 0.701 | Optimizer:Adam,Activaton:Relu,HiddenLayer:2+,Loss:CXE |
| 1 | MLP Regression without Drop out | NAN | 0.106 | Optimizer:Adam with lr=1e-4,Activaton:Relu,HiddenLayer:2,Loss:MSE |
| 2 | Multi Head MLP Model | 56.422 | 2.265 | Optimizer : Adam, Hidden Layers : 2, Activation :Relu Loss: CXE+MSE |
pip install skorch
Collecting skorch
Downloading skorch-0.11.0-py3-none-any.whl (155 kB)
|████████████████████████████████| 155 kB 5.1 MB/s
Requirement already satisfied: tqdm>=4.14.0 in /usr/local/lib/python3.7/dist-packages (from skorch) (4.62.3)
Requirement already satisfied: tabulate>=0.7.7 in /usr/local/lib/python3.7/dist-packages (from skorch) (0.8.9)
Requirement already satisfied: numpy>=1.13.3 in /usr/local/lib/python3.7/dist-packages (from skorch) (1.19.5)
Requirement already satisfied: scikit-learn>=0.19.1 in /usr/local/lib/python3.7/dist-packages (from skorch) (1.0.1)
Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from skorch) (1.4.1)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.19.1->skorch) (1.1.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.19.1->skorch) (3.0.0)
Installing collected packages: skorch
Successfully installed skorch-0.11.0
#from skorch libarary we have taken NeuralNetClassifier In ordeer to use gridserachcv. point to note is that data loader is not compatable with gridseraccv
#we define neural net in the MyModuele class and pass this to the neuralnetclassifier along withh other paramaters such as epochs, batch size and so on.
import numpy as np
from sklearn.datasets import make_classification
from torch import nn
import torch.nn.functional as F
from skorch import NeuralNetClassifier
X, y =img_arr,df['Label']
X = X.astype(np.float32)
y = y.astype(np.int64)
class MyModule(nn.Module):
def __init__(self, num_units=250, nonlin=F.relu):
super(MyModule, self).__init__()
self.dense0 = nn.Linear(32*32*3, num_units)
self.nonlin = nonlin
self.dropout = nn.Dropout(0.5)
self.dense1 = nn.Linear(num_units, 10)
self.output = nn.Linear(10, 2)
def forward(self, X, **kwargs):
X = self.nonlin(self.dense0(X))
X = self.dropout(X)
X = F.relu(self.dense1(X))
X = F.softmax(self.output(X))
return X
net = NeuralNetClassifier(
MyModule,
max_epochs=10,
lr=0.1,
batch_size=500,
criterion=nn.CrossEntropyLoss(),
# Shuffle training data on each epoch
iterator_train__shuffle=True,
)
net.fit(X, y)
y_proba = net.predict_proba(X)
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.8410 0.4715 0.8418 1.2026
2 0.8417 0.4715 0.8418 1.0987
3 0.8420 0.4715 0.8418 1.7653
4 0.8420 0.4715 0.8418 1.4498
5 0.8420 0.4715 0.8418 1.0144
6 0.8420 0.4715 0.8418 1.0376
7 0.8420 0.4715 0.8418 1.0246
8 0.8419 0.4715 0.8418 1.0200
9 0.8421 0.4715 0.8418 1.0212
10 0.8419 0.4715 0.8418 1.0485
#making a pipeline and giving the above defined model to the pipeline
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
pipe = Pipeline([
('scale', StandardScaler()),
('net', net),
])
pipe.fit(X, y)
y_proba = pipe.predict_proba(X)
Re-initializing module.
Re-initializing criterion.
Re-initializing optimizer.
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.6883 0.5613 0.6853 1.0059
2 0.6846 0.5652 0.6825 0.9948
3 0.6827 0.5813 0.6800 1.0123
4 0.6797 0.5918 0.6780 0.9992
5 0.6776 0.5902 0.6770 1.0090
6 0.6756 0.5933 0.6747 0.9882
7 0.6736 0.5941 0.6730 0.9922
8 0.6721 0.5968 0.6721 1.1870
9 0.6681 0.5875 0.6718 1.1562
10 0.6665 0.6010 0.6696 1.1188
#passing the pipeline and the k fold validation value to the gridserachcv
from sklearn.model_selection import GridSearchCV
params = {
'lr': [0.01, 0.02],
'max_epochs': [10, 20],
# 'module__num_units': [10, 20],
}
gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy')
gs.fit(X, y)
print(gs.best_score_, gs.best_params_)
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.7798 0.5286 0.7573 1.9227
2 0.7708 0.5084 0.6952 1.0440
3 0.7074 0.5286 0.6918 0.8070
4 0.6978 0.5286 0.6921 0.6683
5 0.6935 0.5286 0.6920 0.6746
6 0.6930 0.5286 0.6920 0.6700
7 0.6932 0.5286 0.6921 0.6699
8 0.6920 0.5286 0.6920 0.6680
9 0.6917 0.5286 0.6920 0.6885
10 0.6923 0.5286 0.6920 0.6728
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.8376 0.5361 0.7684 0.7247
2 0.8046 0.5286 0.7846 1.5659
3 0.7828 0.5286 0.7846 1.9883
4 0.7762 0.5350 0.7763 1.1947
5 0.7736 0.5350 0.7781 1.0133
6 0.7796 0.5298 0.7831 1.5373
7 0.7814 0.5338 0.7785 1.6236
8 0.7820 0.5292 0.7839 1.9590
9 0.7768 0.5321 0.7803 1.7727
10 0.7763 0.5327 0.7802 1.8375
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.7833 0.5281 0.7843 0.9662
2 0.7822 0.5298 0.7651 0.7128
3 0.7232 0.4534 0.6940 0.6628
4 0.6984 0.4633 0.6933 0.6749
5 0.6948 0.5286 0.6922 0.6631
6 0.6952 0.5286 0.6931 0.6821
7 0.6952 0.5286 0.6930 0.6756
8 0.6934 0.5286 0.6913 0.6691
9 0.6926 0.5286 0.6922 0.6619
10 0.6928 0.5286 0.6890 0.6732
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.7854 0.5286 0.7844 0.7595
2 0.7839 0.5286 0.7833 0.8598
3 0.7820 0.5263 0.7478 0.7545
4 0.7711 0.5327 0.7697 0.6853
5 0.7559 0.5379 0.6960 0.6927
6 0.7073 0.5292 0.6929 0.6735
7 0.6971 0.5292 0.6929 0.6696
8 0.6975 0.5292 0.6929 0.6699
9 0.6952 0.5292 0.6929 0.6689
10 0.6926 0.5292 0.6928 0.6694
11 0.6943 0.5286 0.6928 0.6863
12 0.6927 0.5292 0.6927 0.6753
13 0.6935 0.5286 0.6927 0.6639
14 0.6935 0.5292 0.6927 0.6681
15 0.6949 0.5286 0.6926 0.6789
16 0.6937 0.5286 0.6926 0.6601
17 0.6936 0.5292 0.6925 0.6744
18 0.6927 0.5292 0.6925 0.6689
19 0.6931 0.5292 0.6924 0.6639
20 0.6935 0.5292 0.6924 0.6683
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.7857 0.5286 0.7846 2.3256
2 0.7843 0.5286 0.7846 2.4656
3 0.7841 0.5286 0.7846 2.4923
4 0.7844 0.5286 0.7846 2.5743
5 0.7842 0.5286 0.7846 2.4746
6 0.7842 0.5286 0.7846 2.2714
7 0.7843 0.5286 0.7845 2.3150
8 0.7840 0.5292 0.7840 1.9866
9 0.7836 0.5292 0.7839 1.8893
10 0.7824 0.5298 0.7832 1.7017
11 0.7775 0.5298 0.7833 1.1684
12 0.7882 0.5593 0.7521 0.9809
13 0.7796 0.5142 0.7864 0.7404
14 0.7674 0.5599 0.7504 0.9317
15 0.7677 0.5604 0.7487 1.0228
16 0.7688 0.5333 0.7784 1.0439
17 0.7738 0.5309 0.7821 1.4602
18 0.7768 0.5425 0.7690 1.6386
19 0.7615 0.5558 0.7563 1.0856
20 0.7569 0.5471 0.7648 1.0159
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.7830 0.5408 0.7677 0.9625
2 0.8301 0.4714 0.8419 1.7780
3 0.8424 0.4714 0.8419 1.9779
4 0.8419 0.4714 0.8419 1.9169
5 0.8420 0.4714 0.8419 1.8429
6 0.8421 0.4714 0.8419 1.8707
7 0.8421 0.4714 0.8419 1.9160
8 0.8416 0.4714 0.8419 1.8844
9 0.8426 0.4714 0.8419 1.8989
10 0.8256 0.5321 0.7810 1.6703
11 0.7815 0.5321 0.7811 1.6020
12 0.7833 0.5292 0.7840 1.6963
13 0.7833 0.5292 0.7840 1.5732
14 0.7820 0.5292 0.7839 1.6830
15 0.7760 0.5315 0.7817 1.6328
16 0.7777 0.5304 0.7825 1.5561
17 0.7722 0.5454 0.7665 1.6003
18 0.7679 0.5500 0.7578 1.3997
19 0.7755 0.4997 0.8052 1.3401
20 0.8414 0.4714 0.8419 1.6774
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.8402 0.4714 0.8419 1.7009
2 0.8421 0.4714 0.8419 1.9660
3 0.8422 0.4714 0.8419 1.9978
4 0.8420 0.4714 0.8419 1.8873
5 0.8420 0.4714 0.8419 1.8631
6 0.8421 0.4714 0.8419 1.8674
7 0.8420 0.4714 0.8419 1.7675
8 0.8420 0.4714 0.8419 1.8049
9 0.8421 0.4714 0.8419 1.8091
10 0.8420 0.4714 0.8419 1.6839
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.8401 0.4714 0.8419 2.1247
2 0.8422 0.4714 0.8419 1.9957
3 0.8420 0.4714 0.8419 1.9154
4 0.8418 0.4714 0.8419 1.8425
5 0.8420 0.4714 0.8419 1.9112
6 0.8419 0.4714 0.8419 1.9682
7 0.8418 0.4714 0.8419 1.9526
8 0.8420 0.4714 0.8419 1.8674
9 0.8420 0.4714 0.8419 1.8782
10 0.8420 0.4714 0.8419 1.8734
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.7800 0.5286 0.7845 0.9109
2 0.7831 0.5292 0.7814 1.1988
3 0.7877 0.5165 0.7735 0.9203
4 0.7857 0.5286 0.7846 1.1991
5 0.7833 0.5286 0.7844 1.4953
6 0.7849 0.5286 0.7846 1.5401
7 0.7802 0.5269 0.7840 1.3494
8 0.7790 0.5292 0.7838 1.0285
9 0.7763 0.5292 0.7657 0.7832
10 0.7101 0.5281 0.6933 0.8150
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.7719 0.4789 0.7004 0.7110
2 0.7028 0.5205 0.6926 0.6683
3 0.7060 0.4714 0.6997 0.6662
4 0.7002 0.4714 0.6994 0.6643
5 0.6988 0.4714 0.6990 0.6744
6 0.6974 0.4696 0.6989 0.6708
7 0.6947 0.4662 0.6982 0.6620
8 0.6966 0.4933 0.6942 0.6710
9 0.6975 0.4714 0.6979 0.6632
10 0.6975 0.4714 0.6976 0.6684
11 0.6956 0.5072 0.6934 0.6737
12 0.6970 0.5101 0.6943 0.6637
13 0.6974 0.4951 0.6939 0.6728
14 0.6918 0.5240 0.6915 0.6809
15 0.6904 0.5228 0.6912 0.6834
16 0.6928 0.5090 0.6922 0.6551
17 0.6897 0.5061 0.6926 0.6707
18 0.6927 0.4737 0.6960 0.6665
19 0.6922 0.4714 0.6960 0.6684
20 0.6876 0.5495 0.6891 0.6714
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.7858 0.5286 0.7846 1.6002
2 0.7850 0.5286 0.7846 2.0001
3 0.7843 0.5286 0.7846 2.0577
4 0.7844 0.5286 0.7846 2.0823
5 0.7845 0.5286 0.7846 2.0718
6 0.7844 0.5286 0.7846 1.9521
7 0.7846 0.5286 0.7846 1.9980
8 0.7847 0.5286 0.7846 1.9734
9 0.7844 0.5286 0.7846 1.9378
10 0.7846 0.5286 0.7846 1.9407
11 0.7842 0.5286 0.7846 1.8899
12 0.7846 0.5286 0.7846 1.8110
13 0.7839 0.5286 0.7844 1.6234
14 0.7836 0.5286 0.7844 1.1031
15 0.7831 0.5321 0.7801 1.2371
16 0.7835 0.5286 0.7846 1.4920
17 0.7845 0.5286 0.7846 1.6519
18 0.7845 0.5286 0.7846 1.6300
19 0.7845 0.5286 0.7846 1.5933
20 0.7845 0.5286 0.7846 1.6887
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.7886 0.5286 0.7846 0.8908
2 0.7828 0.5286 0.7818 0.9385
3 0.7655 0.5286 0.6941 0.6964
4 0.7000 0.5286 0.6920 0.6667
5 0.6954 0.5286 0.6920 0.6636
6 0.6934 0.5286 0.6920 0.6694
7 0.6920 0.5286 0.6919 0.6675
8 0.6920 0.5286 0.6919 0.6772
9 0.6916 0.5286 0.6919 0.6787
10 0.6925 0.5286 0.6919 0.6702
11 0.6919 0.5286 0.6919 0.6812
12 0.6922 0.5286 0.6918 0.6787
13 0.6916 0.5286 0.6918 0.6718
14 0.6921 0.5286 0.6918 0.6612
15 0.6910 0.5286 0.6918 0.6680
16 0.6911 0.5286 0.6918 0.6658
17 0.6910 0.5286 0.6918 0.6728
18 0.6913 0.5286 0.6918 0.6861
19 0.6920 0.5286 0.6918 0.6535
20 0.6923 0.5286 0.6918 0.6635
0.5424186333487583 {'lr': 0.02, 'max_epochs': 20}
gs_best_score_c = gs.best_score_*100
exp_name = f"MLP Classification with Drop out"
parameters = f"Optimizer:Adam,Hidden Layers:2,Activation:Relu,DropoutRate:0.5,lr:0.02,max_epochs:20"
experiment_results.loc[len(experiment_results)] = [f"{exp_name}"]+list([np.round(gs_best_score_c,3),"NAN",parameters])
wrap_df_text(experiment_results)
| Model_name | Test_Accuracy | Test_Loss | Parameters | |
|---|---|---|---|---|
| 0 | MLP Classification without Drop out | 55.934 | 0.701 | Optimizer:Adam,Activaton:Relu,HiddenLayer:2+,Loss:CXE |
| 1 | MLP Regression without Drop out | NAN | 0.106 | Optimizer:Adam with lr=1e-4,Activaton:Relu,HiddenLayer:2,Loss:MSE |
| 2 | Multi Head MLP Model | 56.422 | 2.265 | Optimizer : Adam, Hidden Layers : 2, Activation :Relu Loss: CXE+MSE |
| 3 | MLP Classification with Drop out | 54.242 | NAN | Optimizer:Adam,Hidden Layers:2,Activation:Relu,DropoutRate:0.5,lr:0.02,max_epochs:20 |
#from skorch libarary we have taken NeuralNetClassifier In ordeer to use gridserachcv. point to note is that data loader is not compatable with gridseraccv
#we define neural net in the MyModuele class and pass this to the neuralnetclassifier along withh other paramaters such as epochs, batch size and so on.
import numpy as np
from sklearn.datasets import make_classification
from torch import nn
import torch.nn.functional as F
from skorch import NeuralNetRegressor
X, y =img_arr,y_bbox
X = X.astype(np.float32)
y = y.astype(np.float32)
class MyModule(nn.Module):
def __init__(self, num_units=250, nonlin=F.relu):
super(MyModule, self).__init__()
self.dense0 = nn.Linear(32*32*3, num_units)
self.nonlin = nonlin
self.dropout = nn.Dropout(0.5)
self.dense1 = nn.Linear(num_units, 10)
self.output = nn.Linear(10, 4)
def forward(self, X, **kwargs):
X = self.nonlin(self.dense0(X))
X = self.dropout(X)
X = F.relu(self.dense1(X))
X = F.softmax(self.output(X))
return X
net = NeuralNetRegressor(
MyModule,
max_epochs=10,
lr=0.1,
batch_size=500,
criterion=nn.MSELoss(),
# Shuffle training data on each epoch
iterator_train__shuffle=True,
)
net.fit(X, y)
# y_proba = net.predict_proba(X)
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2408 0.2426 1.4396
2 0.2399 0.2426 1.4705
3 0.2399 0.2426 1.3727
4 0.2304 0.2181 1.6201
5 0.2190 0.2181 1.5804
6 0.2190 0.2181 1.5957
7 0.2190 0.2181 1.5173
8 0.2190 0.2181 1.4691
9 0.2190 0.2181 1.3410
10 0.2190 0.2181 1.3870
<class 'skorch.regressor.NeuralNetRegressor'>[initialized](
module_=MyModule(
(dense0): Linear(in_features=3072, out_features=250, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(dense1): Linear(in_features=250, out_features=10, bias=True)
(output): Linear(in_features=10, out_features=4, bias=True)
),
)
#making a pipeline and giving the above defined model to the pipeline
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
pipe = Pipeline([
('scale', StandardScaler()),
('net', net),
])
pipe.fit(X, y)
# y_proba = pipe.predict_proba(X)
Re-initializing module.
Re-initializing criterion.
Re-initializing optimizer.
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2201 0.1681 0.9686
2 0.1433 0.1255 0.9968
3 0.1207 0.1152 0.9740
4 0.1140 0.1113 0.9836
5 0.1111 0.1094 0.9628
6 0.1095 0.1084 0.9624
7 0.1086 0.1076 0.9766
8 0.1078 0.1072 0.9636
9 0.1072 0.1068 0.9680
10 0.1068 0.1065 0.9750
Pipeline(steps=[('scale', StandardScaler()),
('net',
<class 'skorch.regressor.NeuralNetRegressor'>[initialized](
module_=MyModule(
(dense0): Linear(in_features=3072, out_features=250, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
(dense1): Linear(in_features=250, out_features=10, bias=True)
(output): Linear(in_features=10, out_features=4, bias=True)
),
))])
#passing the pipeline to the gridsearchcv
from sklearn.model_selection import GridSearchCV
params = {
'lr': [0.01, 0.02],
'max_epochs': [10, 20],
# 'module__num_units': [10, 20],
}
gs = GridSearchCV(net, params, refit=False, cv=3, scoring='neg_mean_squared_error')
gs.fit(X, y)
print(gs.best_score_, gs.best_params_)
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2231 0.2143 0.6468
2 0.2151 0.1915 0.6668
3 0.2023 0.1346 0.6359
4 0.1559 0.1288 0.6296
5 0.1255 0.1092 0.6326
6 0.1173 0.1123 0.6357
7 0.1119 0.1079 0.6369
8 0.1098 0.1076 0.6353
9 0.1086 0.1065 0.6228
10 0.1082 0.1074 0.6294
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2591 0.2418 0.8673
2 0.2292 0.2107 0.7683
3 0.2209 0.2255 0.9259
4 0.2219 0.2177 0.8717
5 0.2186 0.2184 1.1983
6 0.2187 0.2183 1.5765
7 0.2186 0.2182 1.3955
8 0.2183 0.2181 1.2441
9 0.2192 0.2110 0.8776
10 0.2197 0.2176 0.7795
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2306 0.2184 0.9571
2 0.2171 0.2184 1.2094
3 0.2172 0.2184 1.2071
4 0.2171 0.2184 1.2538
5 0.2171 0.2184 1.2864
6 0.2171 0.2184 1.2530
7 0.2171 0.2184 1.2956
8 0.2171 0.2184 1.2960
9 0.2171 0.2184 1.3945
10 0.2170 0.2184 1.3638
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2628 0.2440 0.8233
2 0.2385 0.2439 0.7221
3 0.2384 0.2439 0.7557
4 0.2386 0.2440 0.7682
5 0.2382 0.2440 0.8028
6 0.2381 0.2440 0.8227
7 0.2382 0.2440 0.8245
8 0.2382 0.2440 0.8221
9 0.2381 0.2440 0.8207
10 0.2382 0.2440 0.8141
11 0.2382 0.2440 0.8231
12 0.2382 0.2440 0.8205
13 0.2381 0.2440 0.8290
14 0.2381 0.2440 0.7695
15 0.2379 0.2439 0.7721
16 0.2377 0.2439 0.7360
17 0.2308 0.1829 0.6903
18 0.1523 0.1239 0.6398
19 0.1313 0.1054 0.6529
20 0.1202 0.1161 0.6402
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2538 0.2418 1.8104
2 0.2391 0.2418 1.8906
3 0.2391 0.2418 1.9417
4 0.2392 0.2418 1.9234
5 0.2392 0.2418 1.8991
6 0.2390 0.2418 1.8682
7 0.2389 0.2418 1.6228
8 0.2379 0.2417 1.2517
9 0.2226 0.2181 0.7377
10 0.2176 0.2183 0.7367
11 0.2169 0.2074 0.6828
12 0.2170 0.2183 0.6551
13 0.2174 0.2183 0.6695
14 0.2160 0.2178 0.6701
15 0.2154 0.1932 0.6648
16 0.2138 0.2180 0.6347
17 0.2142 0.2178 0.6482
18 0.2128 0.2077 0.6610
19 0.2115 0.2104 0.6327
20 0.2075 0.2003 0.6553
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2528 0.2184 1.6406
2 0.2172 0.2184 2.0249
3 0.2172 0.2184 2.1000
4 0.2172 0.2184 2.0991
5 0.2172 0.2184 2.1683
6 0.2171 0.2184 2.1445
7 0.2171 0.2184 2.1187
8 0.2171 0.2184 2.1052
9 0.2171 0.2184 2.1900
10 0.2171 0.2184 2.1855
11 0.2171 0.2184 2.1600
12 0.2171 0.2184 2.1904
13 0.2173 0.2184 2.1488
14 0.2171 0.2184 2.2375
15 0.2171 0.2184 2.1861
16 0.2172 0.2184 2.2793
17 0.2172 0.2184 2.1502
18 0.2172 0.2184 2.2842
19 0.2171 0.2184 2.2757
20 0.2171 0.2184 2.3086
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2314 0.2143 0.8242
2 0.2127 0.2124 0.7757
3 0.2099 0.1633 0.6544
4 0.2041 0.1491 0.6436
5 0.1889 0.1590 0.6337
6 0.1334 0.1056 0.6522
7 0.1111 0.1040 0.6532
8 0.1099 0.1066 0.6394
9 0.1077 0.1041 0.6524
10 0.1076 0.1042 0.6491
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2267 0.2181 0.8445
2 0.2176 0.2182 0.7820
3 0.2160 0.2177 0.7460
4 0.2113 0.1742 0.6729
5 0.2028 0.1399 0.6390
6 0.1849 0.1524 0.6530
7 0.1574 0.1189 0.6470
8 0.1261 0.1065 0.6421
9 0.1148 0.1094 0.6495
10 0.1114 0.1103 0.6503
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2300 0.2099 0.8913
2 0.2154 0.2169 0.6789
3 0.2155 0.2184 0.6985
4 0.2162 0.2184 0.6998
5 0.2138 0.2181 0.6479
6 0.2131 0.2176 0.6655
7 0.2112 0.2116 0.6561
8 0.2092 0.1912 0.6372
9 0.2042 0.2143 0.6549
10 0.1972 0.1931 0.6494
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2734 0.2143 0.6449
2 0.2215 0.2143 0.6648
3 0.2212 0.2143 0.6801
4 0.2210 0.2143 0.6804
5 0.2210 0.2143 0.6875
6 0.2212 0.2143 0.6841
7 0.2211 0.2143 0.6784
8 0.2212 0.2143 0.7125
9 0.2210 0.2143 0.7095
10 0.2210 0.2143 0.6993
11 0.2210 0.2143 0.7093
12 0.2210 0.2143 0.7121
13 0.2210 0.2143 0.7101
14 0.2210 0.2143 0.7161
15 0.2211 0.2143 0.7271
16 0.2210 0.2143 0.7323
17 0.2210 0.2143 0.7441
18 0.2210 0.2143 0.7411
19 0.2210 0.2143 0.7436
20 0.2210 0.2143 0.7391
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2168 0.2038 0.7057
2 0.2089 0.2001 0.6596
3 0.1913 0.1243 0.6509
4 0.1495 0.1893 0.6396
5 0.1248 0.1072 0.6381
6 0.1090 0.1055 0.6485
7 0.1081 0.1054 0.6577
8 0.1076 0.1057 0.6504
9 0.1071 0.1052 0.6454
10 0.1062 0.1051 0.6420
11 0.1065 0.1051 0.7846
12 0.1065 0.1051 0.6490
13 0.1053 0.1059 0.6501
14 0.1062 0.1050 0.6432
15 0.1059 0.1051 0.6403
16 0.1061 0.1051 0.6425
17 0.1058 0.1051 0.6544
18 0.1053 0.1055 0.6632
19 0.1056 0.1051 0.6321
20 0.1056 0.1051 0.6444
epoch train_loss valid_loss dur
------- ------------ ------------ ------
1 0.2279 0.2184 1.5224
2 0.2170 0.2184 1.7007
3 0.2170 0.2184 1.6812
4 0.2172 0.2184 1.6255
5 0.2171 0.2184 1.6176
6 0.2171 0.2184 1.6597
7 0.2171 0.2184 1.7282
8 0.2171 0.2184 1.6030
9 0.2171 0.2184 1.6127
10 0.2171 0.2184 1.6133
11 0.2171 0.2184 1.5649
12 0.2171 0.2184 1.6091
13 0.2171 0.2184 1.5888
14 0.2171 0.2184 1.5932
15 0.2170 0.2184 1.5714
16 0.2171 0.2184 1.6044
17 0.2171 0.2184 1.5626
18 0.2170 0.2184 1.5754
19 0.2171 0.2184 1.6083
20 0.2171 0.2184 1.5679
-0.1359830449024836 {'lr': 0.02, 'max_epochs': 10}
gs.best_score_r = gs.best_score_
exp_name = f"MLP Regression with Drop out"
parameters = f"Optimizer:Adam,Hidden Layers:2,Activation:Relu,DropoutRate:0.5,lr:0.02,max_epochs:20,Loss:MSE"
experiment_results.loc[len(experiment_results)] = [f"{exp_name}"]+list(["NAN",np.round(gs.best_score_r,3),parameters])
wrap_df_text(experiment_results)
| Model_name | Test_Accuracy | Test_Loss | Parameters | |
|---|---|---|---|---|
| 0 | MLP Classification without Drop out | 55.934 | 0.701 | Optimizer:Adam,Activaton:Relu,HiddenLayer:2+,Loss:CXE |
| 1 | MLP Regression without Drop out | NAN | 0.106 | Optimizer:Adam with lr=1e-4,Activaton:Relu,HiddenLayer:2,Loss:MSE |
| 2 | Multi Head MLP Model | 56.422 | 2.265 | Optimizer : Adam, Hidden Layers : 2, Activation :Relu Loss: CXE+MSE |
| 3 | MLP Classification with Drop out | 54.242 | NAN | Optimizer:Adam,Hidden Layers:2,Activation:Relu,DropoutRate:0.5,lr:0.02,max_epochs:20 |
| 4 | MLP Regression with Drop out | NAN | -0.136 | Optimizer:Adam,Hidden Layers:2,Activation:Relu,DropoutRate:0.5,lr:0.02,max_epochs:20,Loss:MSE |
#MLP with three hidden layers with Relu
class MLP_mh_2(nn.Module):
def __init__(self, input_dim):
super().__init__()
self.input_fc = nn.Linear(input_dim, 250)
self.hidden_fc = nn.Linear(250, 150)
self.hidden_fc_2 = nn.Linear(150, 100)
# self.output_fc = nn.Linear(100, output_dim)
self.output_fc_c = nn.Linear(100, 2)
self.output_fc_r = nn.Linear(100, 4)
def forward(self, x):
#x = [batch size, height, width]
batch_size = x.shape[0]
x = x.view(batch_size, -1)
#x = [batch size, height * width]
h_1 = F.relu(self.input_fc(x))
#h_1 = [batch size, 250]
h_2 = F.relu(self.hidden_fc(h_1))
h_3=F.relu(self.hidden_fc_2(h_2))
#h_2 = [batch size, 100]
# y_pred = self.output_fc(h_2)
y_pred_c=self.output_fc_c(h_3)
y_pred_r=self.output_fc_r(h_3)
#y_pred = [batch size, output dim]
return y_pred_c,y_pred_r, h_2
INPUT_DIM = 32 * 32 *3
# OUTPUT_DIM = 2
model_h1 = MLP_mh_2(INPUT_DIM)
model_h1 = model_h1.to(device)
#train accuracy for 2 hidden layers
EPOCHS = 10
best_valid_loss = float('inf')
for epoch in range(EPOCHS):
start_time = time.monotonic()
train_loss, train_acc = train_mh(model_h1, train_iterator_mh, optimizer, criterion_c,criterion_r, device)
valid_loss, valid_acc = evaluate_mh(model_h1, valid_iterator_mh, criterion_c,criterion_r, device)
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), 'tut1-model.pt')
end_time = time.monotonic()
epoch_mins, epoch_secs = epoch_time(start_time, end_time)
print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
Epoch: 01 | Epoch Time: 0m 0s Train Loss: 3.491 | Train Acc: 51.69% Val. Loss: 1.093 | Val. Acc: 51.35% Epoch: 02 | Epoch Time: 0m 0s Train Loss: 3.492 | Train Acc: 51.82% Val. Loss: 1.093 | Val. Acc: 51.35% Epoch: 03 | Epoch Time: 0m 0s Train Loss: 3.492 | Train Acc: 51.74% Val. Loss: 1.093 | Val. Acc: 51.35% Epoch: 04 | Epoch Time: 0m 0s Train Loss: 3.492 | Train Acc: 51.51% Val. Loss: 1.093 | Val. Acc: 51.35% Epoch: 05 | Epoch Time: 0m 0s Train Loss: 3.491 | Train Acc: 51.79% Val. Loss: 1.093 | Val. Acc: 51.35% Epoch: 06 | Epoch Time: 0m 0s Train Loss: 3.491 | Train Acc: 51.65% Val. Loss: 1.093 | Val. Acc: 51.35% Epoch: 07 | Epoch Time: 0m 0s Train Loss: 3.492 | Train Acc: 51.70% Val. Loss: 1.093 | Val. Acc: 51.35% Epoch: 08 | Epoch Time: 0m 0s Train Loss: 3.491 | Train Acc: 51.79% Val. Loss: 1.093 | Val. Acc: 51.35% Epoch: 09 | Epoch Time: 0m 0s Train Loss: 3.492 | Train Acc: 51.83% Val. Loss: 1.093 | Val. Acc: 51.35% Epoch: 10 | Epoch Time: 0m 0s Train Loss: 3.491 | Train Acc: 51.69% Val. Loss: 1.093 | Val. Acc: 51.35%
#model_h1.load_state_dict(torch.load('tut1-model.pt'))
test_loss, test_acc = evaluate_mh(model_h1, test_iterator_mh, criterion_c,criterion_r, device)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
Test Loss: 1.090 | Test Acc: 51.71%
exp_name = f"Multihead with 3 hidden layers & Relu"
parameters = f"Optimizer:Adam,Hidden Layers:3,Activation:Relu,DropoutRate:0.5,lr:0.02,max_epochs:10,Loss:CXE+MSE"
experiment_results.loc[len(experiment_results)] = [f"{exp_name}"]+list([np.round(test_acc*100,3),np.round(test_loss,3),parameters])
wrap_df_text(experiment_results)
| Model_name | Test_Accuracy | Test_Loss | Parameters | |
|---|---|---|---|---|
| 0 | MLP Classification without Drop out | 55.934 | 0.701 | Optimizer:Adam,Activaton:Relu,HiddenLayer:2+,Loss:CXE |
| 1 | MLP Regression without Drop out | NAN | 0.106 | Optimizer:Adam with lr=1e-4,Activaton:Relu,HiddenLayer:2,Loss:MSE |
| 2 | Multi Head MLP Model | 56.422 | 2.265 | Optimizer : Adam, Hidden Layers : 2, Activation :Relu Loss: CXE+MSE |
| 3 | MLP Classification with Drop out | 54.242 | NAN | Optimizer:Adam,Hidden Layers:2,Activation:Relu,DropoutRate:0.5,lr:0.02,max_epochs:20 |
| 4 | MLP Regression with Drop out | NAN | -0.136 | Optimizer:Adam,Hidden Layers:2,Activation:Relu,DropoutRate:0.5,lr:0.02,max_epochs:20,Loss:MSE |
| 5 | Multihead with 3 hidden layers & Relu | 51.714 | 1.09 | Optimizer:Adam,Hidden Layers:3,Activation:Relu,DropoutRate:0.5,lr:0.02,max_epochs:10,Loss:CXE+MSE |
#MLP with three hidden layers with Relu and change in dimensions of neural net
class MLP_mh_2_leaky(nn.Module):
def __init__(self, input_dim):
super().__init__()
self.input_fc = nn.Linear(input_dim, 1500)
self.hidden_fc = nn.Linear(1500, 770)
self.hidden_fc_1 = nn.Linear(770, 250)
self.hidden_fc_2 = nn.Linear(250, 100)
# self.output_fc = nn.Linear(100, output_dim)
self.output_fc_c = nn.Linear(100, 2)
self.output_fc_r = nn.Linear(100, 4)
def forward(self, x):
#x = [batch size, height, width]
batch_size = x.shape[0]
x = x.view(batch_size, -1)
#x = [batch size, height * width]
h_1 = F.leaky_relu_(self.input_fc(x))
#h_1 = [batch size, 250]
h_2 = F.leaky_relu_(self.hidden_fc(h_1))
h_3=F.leaky_relu_(self.hidden_fc_1(h_2))
h_4=F.leaky_relu_(self.hidden_fc_2(h_3))
#h_2 = [batch size, 100]
# y_pred = self.output_fc(h_2)
y_pred_c=self.output_fc_c(h_4)
y_pred_r=self.output_fc_r(h_4)
#y_pred = [batch size, output dim]
return y_pred_c,y_pred_r, h_2
INPUT_DIM = 32 * 32 *3
# OUTPUT_DIM = 2
model_h1_leaky = MLP_mh_2_leaky(INPUT_DIM)
model_h1_leaky = model_h1_leaky.to(device)
#train accuracy for 2 hidden layers with leaky realu and chaniging the dimensions of neural net
EPOCHS = 10
best_valid_loss = float('inf')
for epoch in range(EPOCHS):
start_time = time.monotonic()
train_loss, train_acc = train_mh(model_h1_leaky, train_iterator_mh, optimizer, criterion_c,criterion_r, device)
valid_loss, valid_acc = evaluate_mh(model_h1_leaky, valid_iterator_mh, criterion_c,criterion_r, device)
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), 'tut1-model.pt')
end_time = time.monotonic()
epoch_mins, epoch_secs = epoch_time(start_time, end_time)
print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
Epoch: 01 | Epoch Time: 0m 3s Train Loss: 3.521 | Train Acc: 53.01% Val. Loss: 1.189 | Val. Acc: 52.47% Epoch: 02 | Epoch Time: 0m 3s Train Loss: 3.521 | Train Acc: 53.17% Val. Loss: 1.189 | Val. Acc: 52.47% Epoch: 03 | Epoch Time: 0m 3s Train Loss: 3.521 | Train Acc: 53.18% Val. Loss: 1.189 | Val. Acc: 52.47% Epoch: 04 | Epoch Time: 0m 3s Train Loss: 3.521 | Train Acc: 53.23% Val. Loss: 1.189 | Val. Acc: 52.47% Epoch: 05 | Epoch Time: 0m 3s Train Loss: 3.522 | Train Acc: 53.33% Val. Loss: 1.189 | Val. Acc: 52.47% Epoch: 06 | Epoch Time: 0m 3s Train Loss: 3.522 | Train Acc: 53.26% Val. Loss: 1.189 | Val. Acc: 52.47% Epoch: 07 | Epoch Time: 0m 3s Train Loss: 3.523 | Train Acc: 53.26% Val. Loss: 1.189 | Val. Acc: 52.47% Epoch: 08 | Epoch Time: 0m 3s Train Loss: 3.522 | Train Acc: 53.08% Val. Loss: 1.189 | Val. Acc: 52.47% Epoch: 09 | Epoch Time: 0m 3s Train Loss: 3.521 | Train Acc: 53.14% Val. Loss: 1.189 | Val. Acc: 52.47% Epoch: 10 | Epoch Time: 0m 3s Train Loss: 3.522 | Train Acc: 53.20% Val. Loss: 1.189 | Val. Acc: 52.47%
test_loss, test_acc = evaluate_mh(model_h1, test_iterator_mh, criterion_c,criterion_r, device)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
Test Loss: 1.090 | Test Acc: 51.71%
exp_name = f"Multihead with 3 hidden layers & Leaky Relu"
parameters = f"Optimizer:Adam,Hidden Layers:3,Activation:Leaky Relu,DropoutRate:0.5,lr:0.02,max_epochs:10,Loss:CXE+MSE"
experiment_results.loc[len(experiment_results)] = [f"{exp_name}"]+list([np.round(test_acc*100,3),np.round(test_loss,3),parameters])
wrap_df_text(experiment_results)
| Model_name | Test_Accuracy | Test_Loss | Parameters | |
|---|---|---|---|---|
| 0 | MLP Classification without Drop out | 55.934 | 0.701 | Optimizer:Adam,Activaton:Relu,HiddenLayer:2+,Loss:CXE |
| 1 | MLP Regression without Drop out | NAN | 0.106 | Optimizer:Adam with lr=1e-4,Activaton:Relu,HiddenLayer:2,Loss:MSE |
| 2 | Multi Head MLP Model | 56.422 | 2.265 | Optimizer : Adam, Hidden Layers : 2, Activation :Relu Loss: CXE+MSE |
| 3 | MLP Classification with Drop out | 54.242 | NAN | Optimizer:Adam,Hidden Layers:2,Activation:Relu,DropoutRate:0.5,lr:0.02,max_epochs:20 |
| 4 | MLP Regression with Drop out | NAN | -0.136 | Optimizer:Adam,Hidden Layers:2,Activation:Relu,DropoutRate:0.5,lr:0.02,max_epochs:20,Loss:MSE |
| 5 | Multihead with 3 hidden layers & Relu | 51.714 | 1.09 | Optimizer:Adam,Hidden Layers:3,Activation:Relu,DropoutRate:0.5,lr:0.02,max_epochs:10,Loss:CXE+MSE |
| 6 | Multihead with 3 hidden layers & Leaky Relu | 51.714 | 1.09 | Optimizer:Adam,Hidden Layers:3,Activation:Leaky Relu,DropoutRate:0.5,lr:0.02,max_epochs:10,Loss:CXE+MSE |